Skip to content

Commit 5459406

Browse files
committed
fixed planning at end of trahectory
1 parent 6b07cd7 commit 5459406

File tree

4 files changed

+36
-17
lines changed

4 files changed

+36
-17
lines changed

proj/environment/environment.py

Lines changed: 24 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,8 @@ def plan(self, curr_x, g_traj, itern):
8383
Given the current state and the goal trajectory,
8484
find the next N sates, based on planning
8585
"""
86+
87+
traj_length = len(g_traj)
8688
n_ahead = self.model.planning["n_ahead"]
8789
pred_len = self.model.planning["prediction_length"] + 1
8890

@@ -96,24 +98,35 @@ def plan(self, curr_x, g_traj, itern):
9698
self.model.curr_traj_waypoint_idx = self.curr_traj_waypoint_idx
9799
self.current_traj_waypoint = g_traj[min_idx, :]
98100

99-
start = min_idx + n_ahead
100-
if start > len(g_traj):
101-
start = len(g_traj)
101+
start = min_idx + n_ahead # don't overshoot
102+
if start > traj_length:
103+
start = traj_length
102104

103105
end = min_idx + n_ahead + pred_len
104106

105-
if start + pred_len > len(g_traj):
106-
end = len(g_traj) - 2
107+
if start + pred_len > traj_length:
108+
end = traj_length
107109

108-
if abs(start - end) != pred_len:
109-
g_traj = g_traj[start:end]
110-
len_diff = (end - start) - pred_len
110+
# Make sure planned trajectory has the correct length
111+
if (end - start) != pred_len:
112+
planned = g_traj[start:end]
113+
len_diff = len(planned) - pred_len
111114

112115
if len_diff <= 0:
113-
len_diff = 1
114-
return np.pad(g_traj, ((0, len_diff), (0, 0)), mode="edge")
116+
planned = np.pad(
117+
planned, ((0, abs(len_diff)), (0, 0)), mode="edge"
118+
)
119+
else:
120+
raise ValueError("Something went wrong")
121+
else:
122+
planned = g_traj[start:end]
123+
124+
if len(planned) != pred_len:
125+
raise ValueError(
126+
f"Planned trajecotry length should be {pred_len} but it is {len(planned)} instead"
127+
)
115128
else:
116-
return g_traj[start:end]
129+
return planned
117130

118131
def isdone(self, curr_x, trajectory):
119132
"""

proj/environment/trajectories.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,10 @@ def compute_trajectory_stats(
5050
min_dist_travelled=150,
5151
mute=False,
5252
):
53+
54+
# ! shorten traj
55+
trajectory = trajectory[: planning_params["prediction_length"] + 1, :]
56+
5357
# Compute stats
5458
n_points = len(trajectory)
5559
distance_travelled = np.sum(
@@ -136,11 +140,11 @@ def compute_trajectory_stats(
136140
print(
137141
f"[bold red]Lookahead of {lookahead} is {perc_lookahead} of the # of waypoints, that might be too low. Values closer to 5% are advised."
138142
)
139-
if distance_travelled < min_dist_travelled * params["px_to_cm"]:
140-
logger.warning(
141-
"Distance travelled below minimal requirement, erroring"
142-
)
143-
return None, None
143+
# if distance_travelled < min_dist_travelled * params["px_to_cm"]:
144+
# logger.warning(
145+
# "Distance travelled below minimal requirement, erroring"
146+
# )
147+
# return None, None
144148

145149
return trajectory, duration, metadata
146150

proj/run/runner.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,8 @@ def run_experiment(
117117
logger.exception(
118118
f"Failed to take next step in simulation.\nError: {e}\n\n"
119119
)
120+
break
121+
120122
try:
121123
environment.conclude()
122124
except:

run_allocentric.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
# plot_trajectory(env.reset())
2626

2727
# %%
28-
run_experiment(env, control, model, n_secs=0.1)
28+
run_experiment(env, control, model, n_secs=0.5)
2929

3030

3131
# %%

0 commit comments

Comments
 (0)