Skip to content

Commit af150f9

Browse files
committed
fix
1 parent f6518d9 commit af150f9

File tree

4 files changed

+15
-10
lines changed

4 files changed

+15
-10
lines changed

proj/environment/manager.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,7 @@ def _save_video(self):
107107
animate_from_images(
108108
str(self.frames_folder),
109109
str(self.datafolder / f"{self.exp_name}.mp4"),
110-
int(round(1 / self.mouse.dt)),
110+
int(round(1 / self.model.dt)),
111111
)
112112
except (ValueError, FileNotFoundError):
113113
print("Failed to generate video from frames.. ")

proj/run/runner.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ def render(self, task):
3838

3939
# run
4040
def run_experiment(
41-
environment, controller, model, n_secs=0.5, frames_folder=None,
41+
environment, controller, model, n_secs=8, frames_folder=None,
4242
):
4343
"""
4444
Runs an experiment

rnn/cartesian_rnn.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -85,12 +85,15 @@ def __getitem__(self, idx):
8585
# Get a simulation folder and load data
8686
fld = random.choice(self.simulation_folders)
8787

88-
(
89-
config,
90-
trajectory,
91-
history,
92-
cost_history,
93-
) = load_results_from_folder(fld)
88+
try:
89+
(
90+
config,
91+
trajectory,
92+
history,
93+
cost_history,
94+
) = load_results_from_folder(fld)
95+
except ValueError:
96+
continue
9497

9598
# Get controls history
9699
controls = np.vstack(history[["tau_r", "tau_l"]].values)
@@ -150,7 +153,8 @@ def init_hidden(self,):
150153

151154
def forward(self, X):
152155
# Reshape X: n_steps X batch_size X n_inputs
153-
X = X.unsqueeze(0)
156+
# X = X.unsqueeze(0)
157+
X = X.permute((1, 0, 2))
154158

155159
# for each time step
156160
self.hidden = self.rnn(X, self.hidden)

run_allocentric.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
# plot_trajectory(env.reset())
2626

2727
# %%
28-
run_experiment(env, control, model)
28+
run_experiment(env, control, model, n_secs=0.1)
29+
2930

3031
# %%

0 commit comments

Comments
 (0)