Skip to content

Commit 36b21c4

Browse files
authored
Implement visualizing trajectories (#99)
1 parent e340ca0 commit 36b21c4

File tree

1 file changed

+104
-1
lines changed

1 file changed

+104
-1
lines changed

ada_feeding/ada_feeding/behaviors/move_to.py

Lines changed: 104 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,14 @@
55
logic to move the robot arm using pymoveit2.
66
"""
77
# Standard imports
8+
import csv
89
import math
10+
import os
911
import time
1012

1113
# Third-party imports
1214
from action_msgs.msg import GoalStatus
15+
from ament_index_python.packages import get_package_share_directory
1316
from moveit_msgs.msg import MoveItErrorCodes
1417
import py_trees
1518
from pymoveit2 import MoveIt2State
@@ -44,6 +47,7 @@ def __init__(
4447
terminate_timeout_s: float = 10.0,
4548
terminate_rate_hz: float = 30.0,
4649
planning_service_timeout_s: float = 10.0,
50+
save_trajectory_viz: bool = False,
4751
):
4852
"""
4953
Initialize the MoveTo class.
@@ -62,15 +66,20 @@ def __init__(
6266
processed.
6367
planning_service_timeout_s: How long to wait for the planning service to be
6468
ready before failing.
69+
save_trajectory_viz: Whether to generate and save a visualization of the
70+
trajectory in joint space. This is useful for debugging, but should
71+
be disabled in production.
6572
"""
6673
# Initiatilize the behavior
6774
super().__init__(name=name)
6875

6976
# Store parameters
7077
self.node = node
78+
self.tree_name = tree_name
7179
self.terminate_timeout_s = terminate_timeout_s
7280
self.terminate_rate_hz = terminate_rate_hz
7381
self.planning_service_timeout_s = planning_service_timeout_s
82+
self.save_trajectory_viz = save_trajectory_viz
7483

7584
# Initialization the blackboard for this behavior
7685
self.move_to_blackboard = self.attach_blackboard_client(
@@ -116,7 +125,7 @@ def __init__(
116125

117126
# Initialize the blackboard to read from the parent behavior tree
118127
self.tree_blackboard = self.attach_blackboard_client(
119-
name=name + " MoveTo", namespace=tree_name
128+
name=name + " MoveTo", namespace=self.tree_name
120129
)
121130
# Feedback from MoveTo for the ROS2 Action Server
122131
self.tree_blackboard.register_key(
@@ -294,6 +303,10 @@ def update(self) -> py_trees.common.Status:
294303
if self.cartesian and self.moveit2.max_velocity > 0.0:
295304
MoveTo.scale_velocity(self.traj, self.moveit2.max_velocity)
296305

306+
# Save the trajectory visualization
307+
if self.save_trajectory_viz:
308+
MoveTo.visualize_trajectory(self.tree_name, self.traj)
309+
297310
# Set the trajectory's initial distance to goal
298311
self.tree_blackboard.motion_initial_distance = float(
299312
len(self.traj.points)
@@ -444,6 +457,11 @@ def joint_position_dist(point1: float, point2: float) -> float:
444457
"""
445458
Given two joint positions in radians, this function computes the
446459
distance between then, accounting for rotational symmetry.
460+
461+
Parameters
462+
----------
463+
point1: The first joint position, in radians.
464+
point2: The second joint position, in radians.
447465
"""
448466
abs_dist = abs(point1 - point2) % (2 * math.pi)
449467
return min(abs_dist, 2 * math.pi - abs_dist)
@@ -510,3 +528,88 @@ def get_distance_to_goal(self) -> float:
510528
# Because the robot may still have slight motion even after this point,
511529
# we conservatively return 1.0.
512530
return 1.0
531+
532+
@staticmethod
533+
def visualize_trajectory(action_name: str, traj: JointTrajectory) -> None:
534+
"""
535+
Generates a visualization of the positions, velocities, and accelerations
536+
of each joint in the trajectory. Saves the visualization in the share
537+
directory for `ada_feeding`, as `trajectories/{timestamp}_{action}.png`.
538+
Also saves a CSV of the trajectory.
539+
540+
Parameters
541+
----------
542+
action_name: The name of the action that generated this trajectory.
543+
traj: The trajectory to visualize.
544+
"""
545+
546+
# pylint: disable=too-many-locals
547+
# Necessary because this function saves both the image and CSV
548+
549+
# pylint: disable=import-outside-toplevel
550+
# No need to import graphing libraries if we aren't saving the trajectory
551+
import matplotlib.pyplot as plt
552+
553+
# Get the filepath, excluding the extension
554+
file_dir = os.path.join(
555+
get_package_share_directory("ada_feeding"),
556+
"trajectories",
557+
)
558+
if not os.path.exists(file_dir):
559+
os.mkdir(file_dir)
560+
filepath = os.path.join(
561+
file_dir,
562+
f"{int(time.time()*10**9)}_{action_name}",
563+
)
564+
565+
# Generate the CSV header
566+
csv_header = ["time_from_start"]
567+
for descr in ["Position", "Velocity", "Acceleration"]:
568+
for joint_name in traj.joint_names:
569+
csv_header.append(f"{joint_name} {descr}")
570+
csv_data = [csv_header]
571+
572+
# Generate the axes for the graph
573+
nrows = 2
574+
ncols = int(math.ceil(len(traj.joint_names) / float(nrows)))
575+
fig, axes = plt.subplots(nrows, ncols, figsize=(20, 10))
576+
positions = [[] for _ in traj.joint_names]
577+
velocities = [[] for _ in traj.joint_names]
578+
accelerations = [[] for _ in traj.joint_names]
579+
time_from_start = []
580+
581+
# Loop over the trajectory
582+
for point in traj.points:
583+
timestamp = (
584+
point.time_from_start.sec + point.time_from_start.nanosec * 10.0**-9
585+
)
586+
row = [timestamp]
587+
time_from_start.append(timestamp)
588+
for descr in ["positions", "velocities", "accelerations"]:
589+
for i, joint_name in enumerate(traj.joint_names):
590+
row.append(getattr(point, descr)[i])
591+
if descr == "positions":
592+
positions[i].append(getattr(point, descr)[i])
593+
elif descr == "velocities":
594+
velocities[i].append(getattr(point, descr)[i])
595+
elif descr == "accelerations":
596+
accelerations[i].append(getattr(point, descr)[i])
597+
csv_data.append(row)
598+
599+
# Generate and save the figure
600+
for i, joint_name in enumerate(traj.joint_names):
601+
ax = axes[i // ncols, i % ncols]
602+
ax.plot(time_from_start, positions[i], label="Position (rad)")
603+
ax.plot(time_from_start, velocities[i], label="Velocity (rad/s)")
604+
ax.plot(time_from_start, accelerations[i], label="Acceleration (rad/s^2)")
605+
ax.set_xlabel("Time (s)")
606+
ax.set_title(joint_name)
607+
ax.legend()
608+
fig.tight_layout()
609+
fig.savefig(filepath + ".png")
610+
plt.clf()
611+
612+
# Save the CSV
613+
with open(filepath + ".csv", "w", encoding="utf-8") as csv_file:
614+
csv_writer = csv.writer(csv_file)
615+
csv_writer.writerows(csv_data)

0 commit comments

Comments
 (0)