Hi,

We're using `roslibpy` as an interface for running inference with a PyTorch model and providing those predictions back to ROS. Our stack looks something like this:
- Cameras publish RGB/depth images as ROS topics
- A Python script subscribes to these topics
- Incoming images are processed at some frequency, and the PyTorch model computes a prediction on the GPU
- The prediction is published back as another ROS topic for the robot to consume
We had some success doing this with `rclpy` directly, but due to dependency conflicts we're moving to a standalone script + `roslibpy`. A simplified version of our setup:
```python
import time

import numpy as np
import roslibpy
import torch

# Helpers such as initialize_model, compute_pcd, extract_images_from_messages,
# publish_msg, and shutdown, plus in_channels and the argparse `args` object,
# are defined elsewhere and omitted here for brevity.


class HighLevelPolicy:
    def __init__(
        self,
        run_id,
        rosbridge_host="localhost",
        rosbridge_port=9090,
    ):
        self.device = "cuda"

        # Initialize model
        self.model = initialize_model(run_id, in_channels, self.device)

        # Initialize roslibpy client
        self.client = roslibpy.Ros(host=rosbridge_host, port=rosbridge_port)
        self.client.run()

        # Initialize subscribers
        self.rgb_sub = roslibpy.Topic(
            self.client, "/rgb/image_rect", "sensor_msgs/Image"
        )
        self.depth_sub = roslibpy.Topic(
            self.client, "/depth_registered/image_rect", "sensor_msgs/Image"
        )

        # Initialize publisher
        self.goal_pub = roslibpy.Topic(
            self.client, "/goal_prediction", "sensor_msgs/PointCloud2"
        )

        # Data storage
        self.latest_rgb = None
        self.latest_depth = None
        self.camera_info = None

        # Subscribe to topics
        self.rgb_sub.subscribe(self.rgb_callback)
        self.depth_sub.subscribe(self.depth_callback)

        print("HighLevelPolicy initialized.")

    def run_loop(self, interval=1.0):
        """Main loop that runs synchronously every `interval` seconds."""
        try:
            while True:
                self.timer_callback()
                time.sleep(interval)
        except KeyboardInterrupt:
            print("Interrupted by user. Shutting down.")
            self.shutdown()

    def rgb_callback(self, msg):
        self.latest_rgb = msg

    def depth_callback(self, msg):
        self.latest_depth = msg

    def camera_info_callback(self, msg):
        self.camera_info = msg

    def gripper_pcd_callback(self, msg):
        self.latest_gripper_pcd = msg

    def timer_callback(self):
        if all([self.latest_rgb, self.latest_depth, self.camera_info]):
            # Extract data from messages
            rgb, depth = self.extract_images_from_messages(
                self.latest_rgb, self.latest_depth
            )
            pcd_xyz = compute_pcd(rgb, depth)

            # Run inference
            goal_prediction = inference(self.model, pcd_xyz, self.device)

            # Publish goal prediction
            self.publish_msg(goal_prediction, header=self.latest_depth["header"])
        else:
            print("Waiting for images...")


def inference(model, pcd_xyz, device):
    with torch.no_grad():
        pcd_xyz = torch.from_numpy(pcd_xyz.astype(np.float32)).to(device)
        goal_prediction = model(pcd_xyz)
    return goal_prediction


policy = HighLevelPolicy(
    run_id=args.run_id,
    rosbridge_host=args.rosbridge_host,
    rosbridge_port=args.rosbridge_port,
)
policy.run_loop()
```
The issue we're facing is that calling `inference()` becomes dramatically slower once we've begun subscribing to topics. The baseline inference time is around 300 ms, and that holds when I check it before the first subscription (`self.rgb_sub.subscribe(self.rgb_callback)`). However, if I check the inference time after that line has executed, it increases roughly 100x, to about 30 seconds. If I unsubscribe from the topic, performance goes back to normal.
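For concreteness, here is a minimal sketch of the kind of timing check I mean (not our exact code; the model and point cloud stand in for whatever the policy is currently using, and `torch.cuda.synchronize()` is there so the measurement covers completed GPU work rather than just asynchronous kernel launches):

```python
import time

import torch


def time_inference(model, pcd_xyz, device="cuda", n_runs=10):
    """Rough wall-clock timing of a forward pass (sketch only)."""
    pcd = torch.from_numpy(pcd_xyz.astype("float32")).to(device)
    with torch.no_grad():
        model(pcd)                  # warm-up
        torch.cuda.synchronize()    # drain queued GPU work before timing
        start = time.perf_counter()
        for _ in range(n_runs):
            model(pcd)
        torch.cuda.synchronize()    # wait for all kernels to finish
    return (time.perf_counter() - start) / n_runs
```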
It's unclear to me why subscribing to a topic with a lightweight callback should degrade GPU performance so significantly. I've tried a few things: moving the ROS callbacks to another thread (did not help), setting `queue_length=1` (did not help), and setting `throttle_rate=1000` (this helped a lot, but it's still slow). We did not face any such issue with `rclpy`; I assume it has a different threading model.
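For reference, a minimal sketch of the kind of thing I mean by those mitigations (not our exact code; `queue_length` and `throttle_rate` are passed to the `roslibpy.Topic` constructor, and the queue/worker names here are made up for illustration):

```python
import queue
import threading

import roslibpy

client = roslibpy.Ros(host="localhost", port=9090)
client.run()

# throttle_rate is in milliseconds; queue_length=1 keeps only the newest message
rgb_sub = roslibpy.Topic(
    client,
    "/rgb/image_rect",
    "sensor_msgs/Image",
    queue_length=1,
    throttle_rate=1000,
)

# "Callbacks on another thread": the rosbridge callback only enqueues the raw
# message; a worker thread drains the queue, so the callback itself stays tiny.
msg_queue = queue.Queue(maxsize=1)


def rgb_callback(msg):
    try:
        msg_queue.put_nowait(msg)
    except queue.Full:
        pass  # drop the frame if the worker hasn't consumed the previous one


def worker():
    while True:
        msg = msg_queue.get()
        # ...store / preprocess the image for the next timer_callback()...


threading.Thread(target=worker, daemon=True).start()
rgb_sub.subscribe(rgb_callback)
```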
Any help with this would be much appreciated!