
Subscribing to topics dramatically slows down model inference on GPU #132

@sriramsk1999

Description

Hi,

We're using roslibpy as an interface for running inference with a PyTorch model and feeding the predictions back to ROS. Our stack looks something like this:

  1. Cameras publish RGB/depth images as ROS topics
  2. A Python script subscribes to these topics
  3. A loop processes the incoming images at a fixed frequency and runs the PyTorch model on the GPU to compute a prediction
  4. The prediction is published back as another ROS topic for the robot to consume

We had some success doing this with rclpy directly, but due to dependency conflicts we're moving to a standalone script + roslibpy. A simplified version of our setup (helpers like initialize_model, extract_images_from_messages, compute_pcd, publish_msg, and shutdown are defined elsewhere in the full script):

import time

import numpy as np
import roslibpy
import torch


class HighLevelPolicy:
    def __init__(
        self,
        run_id,
        rosbridge_host="localhost",
        rosbridge_port=9090,
    ):
        self.device = "cuda"
        # Initialize model (in_channels is defined in the full script)
        self.model = initialize_model(run_id, in_channels, self.device)

        # Initialize roslibpy client
        self.client = roslibpy.Ros(host=rosbridge_host, port=rosbridge_port)
        self.client.run()

        # Initialize subscribers
        self.rgb_sub = roslibpy.Topic(
            self.client, "/rgb/image_rect", "sensor_msgs/Image"
        )
        self.depth_sub = roslibpy.Topic(
            self.client, "/depth_registered/image_rect", "sensor_msgs/Image"
        )
        # Initialize publisher
        self.goal_pub = roslibpy.Topic(
            self.client, "/goal_prediction", "sensor_msgs/PointCloud2"
        )

        # Data storage (the camera_info and gripper_pcd subscriptions are
        # elided in this simplified version)
        self.latest_rgb = None
        self.latest_depth = None
        self.camera_info = None
        self.latest_gripper_pcd = None

        # Subscribe to topics
        self.rgb_sub.subscribe(self.rgb_callback)
        self.depth_sub.subscribe(self.depth_callback)
        print("HighLevelPolicy initialized.")

    def run_loop(self, interval=1.0):
        """Main loop that runs synchronously every `interval` seconds."""
        try:
            while True:
                self.timer_callback()
                time.sleep(interval)
        except KeyboardInterrupt:
            print("Interrupted by user. Shutting down.")
            self.shutdown()

    def rgb_callback(self, msg):
        self.latest_rgb = msg

    def depth_callback(self, msg):
        self.latest_depth = msg

    def camera_info_callback(self, msg):
        self.camera_info = msg

    def gripper_pcd_callback(self, msg):
        self.latest_gripper_pcd = msg

    def timer_callback(self):
        if all([self.latest_rgb, self.latest_depth, self.camera_info]):
            # Extract data from messages
            rgb, depth = self.extract_images_from_messages(self.latest_rgb, self.latest_depth)
            pcd_xyz = compute_pcd(rgb, depth)

            # Run inference
            goal_prediction = inference(self.model, pcd_xyz, self.device)

            # Publish goal prediction
            self.publish_msg(goal_prediction, header=self.latest_depth["header"])
        else:
            print("Waiting for images...")

def inference(model, pcd_xyz, device):
    with torch.no_grad():
        pcd_xyz = torch.from_numpy(pcd_xyz.astype(np.float32)).to(device)
        goal_prediction = model(pcd_xyz)
        return goal_prediction

# args parsed with argparse in the full script
policy = HighLevelPolicy(
    run_id=args.run_id,
    rosbridge_host=args.rosbridge_host,
    rosbridge_port=args.rosbridge_port,
)
policy.run_loop()

The issue we're facing is that inference() becomes dramatically slower once we start subscribing to topics. Baseline inference time is around 300 ms, and this holds when I measure before the first subscription (self.rgb_sub.subscribe(self.rgb_callback)). If I measure after that line has executed, inference time increases by roughly 100x, to about 30 seconds. If I unsubscribe from the topic, performance returns to normal. (A sketch of one way to measure this follows.)
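For anyone reproducing this, the slowdown can be measured along these lines (a sketch, not our exact code; measure_inference_time is a hypothetical helper wrapping the inference() function above, with torch.cuda.synchronize() so the timer reflects completed GPU work rather than asynchronous kernel launches):

import time
import torch

def measure_inference_time(model, pcd_xyz, device, n_runs=5):
    # Warm-up run so one-time CUDA initialization is excluded
    inference(model, pcd_xyz, device)
    torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(n_runs):
        inference(model, pcd_xyz, device)
    # Block until all queued GPU work has finished before stopping the clock
    torch.cuda.synchronize()
    return (time.perf_counter() - start) / n_runs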

It's unclear to me why subscribing to a topic with a lightweight callback should degrade GPU performance so severely. I've tried a few things: moving the ROS callbacks to another thread (did not help), setting queue_length=1 (did not help), and setting throttle_rate=1000 (this helped a lot, but inference is still slow); see the snippet below for how these were applied. We did not face any such issue with rclpy; I assume it has a different threading model.
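For concreteness, roslibpy.Topic accepts both settings as constructor arguments (throttle_rate is the minimum number of milliseconds between messages forwarded by rosbridge); the values below are the ones I tried:

self.rgb_sub = roslibpy.Topic(
    self.client,
    "/rgb/image_rect",
    "sensor_msgs/Image",
    queue_length=1,      # keep only the newest message on the bridge side
    throttle_rate=1000,  # at most one message per second from rosbridge
)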

Any help with this would be much appreciated!
