[Feature Request] Natively support training and inference on different PyTorch devices #2194

@ogiarch

Description

🚀 Feature

Natively allow SB3 models to be configured with two different PyTorch devices: an "inference device" used by methods like predict(...) and a "training device" used by methods like learn(...). For example, when learn(...) is invoked, move the model to the training device and run training; when predict(...) is invoked, move the model back to the inference device for inference.
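A minimal sketch of the requested behavior, assuming only a PyTorch-style `.to(device)` method like the one on `torch.nn.Module`. The helper `temporarily_on` and the `FakePolicy` stand-in are hypothetical names used for illustration; the stand-in just records device moves so the sketch runs without PyTorch installed. Because SB3's `learn()` alternates between collecting rollouts (inference) and calling `train()` (gradient updates), wrapping only the training phase is what keeps inference on the inference device.

```python
from contextlib import contextmanager
from typing import Iterator, Protocol


class Movable(Protocol):
    """Anything with a PyTorch-style `.to(device)` method, e.g. torch.nn.Module."""

    def to(self, device: str) -> "Movable": ...


@contextmanager
def temporarily_on(module: Movable, training_device: str, inference_device: str) -> Iterator[Movable]:
    """Move `module` to `training_device` for the block, then back to `inference_device`.

    Hypothetical helper: the name and signature are illustrative, not part of SB3.
    """
    module.to(training_device)
    try:
        yield module
    finally:
        # Return to the inference device even if the training step raises.
        module.to(inference_device)


class FakePolicy:
    """Stand-in for a policy network that only records which device it is on."""

    def __init__(self, device: str = "cpu") -> None:
        self.device = device
        self.history: list[str] = []

    def to(self, device: str) -> "FakePolicy":
        self.device = device
        self.history.append(device)
        return self


def train_round(policy: FakePolicy) -> str:
    """Placeholder for one round of gradient updates; reports where it ran."""
    return policy.device


policy = FakePolicy("cpu")
with temporarily_on(policy, training_device="mps", inference_device="cpu"):
    ran_on = train_round(policy)

print(ran_on, policy.device, policy.history)  # mps cpu ['mps', 'cpu']
```

In an SB3 subclass, the natural hook would presumably be an override of `train()` that calls `super().train()` inside the `with` block, passing `self.policy`. This is an assumption rather than a tested recipe: anything else that creates tensors on its own device during training, such as the rollout buffer's `device` attribute, would likely need to follow the policy as well.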

Motivation

Disclaimer: I am a professional software engineer, but am completely self-taught when it comes to RL, so my general coding ability is a little better than my deep understanding of these algorithms. Take everything I say with appropriate grains of salt.

For about a year, in my spare time, I have been using a custom extension of the sb3_contrib MaskablePPO class for a personal board-game development project, and I have been trying to speed up long-running simulations on my macOS laptop. (Note: in case my use case affects my results, I should mention that because I am using it for MARL, with multiple models sharing the same environment, I am abusing the algorithm a little by limiting the number of parallel envs to 1. I hope to change this in the future, but it will be a heavy lift involving some gnarly multiprocessing code that will be a real slog to build out in Python.)

For a long time I shied away from running anything on my GPU, because the official SB3 documentation on PPO states very clearly that "PPO is meant to be run primarily on the CPU, especially when you are not using a CNN." I assumed the reason for this guidance was that running inference on a GPU tends to dramatically worsen performance, because 1) inference runs on every environment step and interacts heavily with the model's environment, 2) for now, most Gymnasium environments and simulations are very CPU-bound, and 3) round trips between CPU and GPU memory are far too slow to make on every simulation step. Profiling with py-spy confirmed this: when the GPU was used for inference, the vast majority of CPU time went to shipping observations to the GPU and, especially, to retrieving the selected actions from it.

However, we don't need to use the same PyTorch device for training and inference, right? Training is particularly slow on CPUs and particularly fast on many GPUs, and because it runs relatively infrequently, I figured it might be worth temporarily shipping the whole model to the GPU before each training round and then back to the CPU before inference resumes right afterward. Apart from a nasty bug that caused the model's optimizer to be loaded on the wrong device, it was relatively straightforward to override a few methods and build a model class that takes separate inference_device and training_device arguments and does exactly this.
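One plausible source of the optimizer bug mentioned above (an assumption; the issue does not say what the bug was) is that `nn.Module.to()` moves a module's parameters but not the optimizer's per-parameter state, such as Adam's running averages, which lives in `torch.optim.Optimizer.state`. The hypothetical helper `move_optimizer_state` below moves every tensor-like value in a mapping of that shape; `FakeTensor` is a stand-in so the sketch runs without PyTorch.

```python
from typing import Any, Dict, MutableMapping


def move_optimizer_state(state: MutableMapping[Any, Dict[str, Any]], device: str) -> None:
    """Move tensor-like values in an optimizer's per-parameter state to `device`, in place.

    Hypothetical helper. `state` is assumed to have the shape of
    torch.optim.Optimizer.state: {param: {"exp_avg": tensor, ...}}.
    """
    for param_state in state.values():
        for key, value in param_state.items():
            # torch.Tensor has .to(); plain Python numbers do not and are left alone.
            # Reassigning an existing key does not change the dict's size,
            # so it is safe during iteration.
            if hasattr(value, "to"):
                param_state[key] = value.to(device)


class FakeTensor:
    """Stand-in for torch.Tensor that only tracks its device."""

    def __init__(self, device: str) -> None:
        self.device = device

    def to(self, device: str) -> "FakeTensor":
        # Like torch.Tensor.to, return a new object rather than mutating in place.
        return FakeTensor(device)


state = {
    "param_0": {"exp_avg": FakeTensor("cpu"), "exp_avg_sq": FakeTensor("cpu"), "step": 3},
}
move_optimizer_state(state, "mps")
print({key: getattr(value, "device", value) for key, value in state["param_0"].items()})
# {'exp_avg': 'mps', 'exp_avg_sq': 'mps', 'step': 3}
```

With a real SB3 model this would presumably be called as `move_optimizer_state(model.policy.optimizer.state, "mps")` right after `model.policy.to("mps")`, and again with `"cpu"` on the way back.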

The result was an absolutely massive speed-up. Training rounds went from about two seconds to about half a second on average when moved to 'mps', and the gain was especially pronounced when running multiple simulations at once. Parallel "tournament rounds" of thousands of games that I had set up went from taking 2.5 to 3 hours to about 70 minutes. The improvement was large enough that I think some version of this functionality should be broadly available in SB3. If the results remain statistically sound, it seems that by default we're leaving a ton of performance on the table on most computers capable of running PPO and related algorithms.

Curious to hear others' thoughts, though, especially as I'm new to the space.

Alternatives

  1. Given that in many pure RL use cases it's a bad idea to run inference on anything but the CPU, another option would be to use the user-specified device only for training and always use the CPU for inference. I don't like this, though, as it would be a regression in SB3's current device flexibility.
  2. We could just... not do any of this. As I've said, though, unless I'm missing something big regarding statistical soundness when switching between devices like this, I think this would be a missed opportunity for an enormous speed-up for many developers.

Checklist

  • I have checked that there is no similar issue in the repo
  • If I'm requesting a new feature, I have proposed alternatives

Metadata

Assignees: No one assigned
Labels: enhancement (New feature or request)
Type: No type
Projects: No projects
Milestone: No milestone
Relationships: None yet
Development: No branches or pull requests
