Skip to content

Commit a7d3749

Browse files
committed
goal plotting
1 parent 1eb556e commit a7d3749

File tree

3 files changed

+33
-10
lines changed

3 files changed

+33
-10
lines changed

proj/environment/manager.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -111,10 +111,15 @@ def _save_results(self):
111111
)
112112

113113
# save the last frame as a results image
114-
last_frame = [
115-
f for f in self.frames_folder.glob("*.png") if f.is_file()
116-
][-1]
117-
shutil.copy(str(last_frame), str(self.datafolder / "final_frame.png"))
114+
try:
115+
last_frame = [
116+
f for f in self.frames_folder.glob("*.png") if f.is_file()
117+
][-1]
118+
shutil.copy(
119+
str(last_frame), str(self.datafolder / "final_frame.png")
120+
)
121+
except IndexError:
122+
pass # no frames were saved
118123

119124
# save cost history
120125
pd.DataFrame(self.cost_history).to_hdf(
@@ -149,7 +154,7 @@ def conclude(self):
149154
self._log_conf()
150155
self._save_results()
151156

152-
if self.model.PLOT_LIVE:
157+
if self.model.LIVE_PLOT:
153158
self._save_video()
154159

155160
# save summary plot

proj/environment/plotter.py

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -238,11 +238,26 @@ def _plot_cost(self, keep_s=1.2):
238238
)
239239
ax.legend()
240240

241-
# def _plot_goal(self, goal):
242-
# goal = self.model._goal(*goal)
243-
# x = self.model.curr_x
241+
def _plot_goal(self, goal):
242+
self.goal_ax.clear()
243+
goal = self.model._state(*goal)
244+
x = self.model.curr_x
245+
246+
for n, k in enumerate(x._fields):
247+
self.goal_ax.bar(n, goal._asdict()[k], color=colors[k], alpha=0.7)
248+
self.goal_ax.scatter(
249+
n,
250+
x._asdict()[k],
251+
color=desaturate_color(colors[k]),
252+
s=200,
253+
zorder=99,
254+
lw=1,
255+
edgecolors="white",
256+
label=k,
257+
)
244258

245-
# for k in self.model.curr_x._fields:
259+
self.goal_ax.legend()
260+
self.goal_ax.set(xticks=[])
246261

247262
def visualize_world_live(self, curr_goals, elapsed=None):
248263
ax = self.xy_ax
@@ -289,6 +304,9 @@ def visualize_world_live(self, curr_goals, elapsed=None):
289304
# plot cost
290305
self._plot_cost()
291306

307+
# plot goal
308+
self._plot_goal(curr_goals[0, :])
309+
292310
# display plot
293311
self.f.canvas.draw()
294312
plt.pause(0.00001)

proj/model/config.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ class Config:
66

77
USE_FAST = True # if true use cumba's methods
88
SPAWN_TYPE = "trajectory"
9-
LIVE_PLOT = False
9+
LIVE_PLOT = True
1010

1111
# ----------------------------- Simulation params ---------------------------- #
1212
dt = 0.005

0 commit comments

Comments
 (0)