Skip to content

Commit 5e3fd41

Browse files
committed
px->cm
1 parent 4d3484c commit 5e3fd41

File tree

2 files changed

+13
-7
lines changed

2 files changed

+13
-7
lines changed

proj/environment/trajectories.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -32,13 +32,15 @@ def complete_given_xy(x, y, params, planning_params):
3232

3333
trajectory = np.vstack([x, y, angle, speed, ang_speed]).T
3434
return (
35-
compute_trajectory_stats(trajectory, len(trajectory), planning_params),
35+
compute_trajectory_stats(
36+
trajectory, len(trajectory), params, planning_params
37+
),
3638
None,
3739
)
3840

3941

4042
def compute_trajectory_stats(
41-
trajectory, duration, planning_params, min_dist_travelled=150
43+
trajectory, duration, params, planning_params, min_dist_travelled=150
4244
):
4345
# Compute stats
4446
n_points = len(trajectory)
@@ -110,7 +112,7 @@ def compute_trajectory_stats(
110112
extra={"markdown": True},
111113
)
112114

113-
if distance_travelled < min_dist_travelled:
115+
if distance_travelled < min_dist_travelled * params["px_to_cm"]:
114116
log.warning("Distance travelled below minimal requirement, erroring")
115117
return None, None
116118

@@ -207,14 +209,15 @@ def from_tracking(n_steps, params, planning_params, cache_fld, *args):
207209
except:
208210
fps = 60
209211

210-
x = trial.body_xy[:, 0]
211-
y = trial.body_xy[:, 1]
212+
x = trial.body_xy[:, 0] * params["px_to_cm"]
213+
y = trial.body_xy[:, 1] * params["px_to_cm"]
212214

213215
angle = interpolate_nans(trial.body_orientation)
214216
angle = np.radians(90 - angle)
215217
angle = np.unwrap(angle)
216218

217219
speed = line_smoother(trial.body_speed) * fps
220+
speed *= params["px_to_cm"]
218221
ang_speed = np.ones_like(speed) # it will be ignored
219222

220223
# get start frame
@@ -241,7 +244,7 @@ def from_tracking(n_steps, params, planning_params, cache_fld, *args):
241244

242245
return (
243246
compute_trajectory_stats(
244-
trajectory, len(x[start:]) / fps, planning_params
247+
trajectory, len(x[start:]) / fps, params, planning_params
245248
),
246249
trials,
247250
)

proj/model/config.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,11 +63,14 @@ class Config:
6363

6464
trajectory = dict( # parameters of the goals trajectory
6565
name="tracking",
66+
# ? For artificial trajectories
6667
nsteps=10000,
6768
distance=150,
6869
max_speed=100,
6970
min_speed=80,
7071
min_dist=5, # if agent is within this distance from trajectory end the goal is considered achieved
72+
# ? for trajectories from data
73+
px_to_cm=1 / 30.8, # convert px values to cm
7174
# dist_th=60, # keep frames only after moved away from start location
7275
dist_th=-1,
7376
resample=True, # if True when using tracking trajectory resamples it
@@ -82,7 +85,7 @@ class Config:
8285
)
8386

8487
# --------------------------------- Plotting --------------------------------- #
85-
traj_plot_every = 15
88+
traj_plot_every = 80
8689

8790
# ------------------------------ Control params ------------------------------ #
8891
iLQR = dict(

0 commit comments

Comments
 (0)