Description
Outline & Motivation
While looking into throughput monitoring, I noticed that dtype inference currently relies on explicit isinstance(...) checks against various precision plugin classes.
In both Fabric and PyTorch, the helper _plugin_to_compute_dtype(...):
- Imports multiple precision plugin implementations
- Performs chained type checks
- Accesses internal attributes like _desired_input_dtype and _desired_dtype
- Returns hard-coded dtype mappings per plugin
Although this works, it tightly couples throughput utilities to specific precision plugin implementations. As a result:
- Adding new precision plugins may require modifying throughput logic
- Internal plugin attributes are accessed outside the plugin itself
- Custom precision plugins cannot integrate seamlessly without touching throughput utilities
It feels like the responsibility for exposing the compute dtype should live within the precision plugin itself rather than being inferred externally.
Pitch
I propose introducing a small API on the Precision base class:
```python
def compute_dtype(self) -> torch.dtype:
    return torch.float32
```
Each built-in precision plugin would override this method to return its effective compute dtype (e.g., half precision returning torch.float16, transformer engine returning torch.int8, etc.).
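A hedged sketch of what those overrides could look like. The class names mirror the built-in plugins only loosely; the `torch.int8` mapping for Transformer Engine follows the hard-coded value mentioned above:

```python
import torch

class Precision:
    """Stand-in for the Precision base class (hypothetical)."""

    def compute_dtype(self) -> torch.dtype:
        # Safe default: full precision.
        return torch.float32

class HalfPrecision(Precision):
    def compute_dtype(self) -> torch.dtype:
        return torch.float16

class TransformerEnginePrecision(Precision):
    def compute_dtype(self) -> torch.dtype:
        # Mirrors the existing hard-coded mapping in the helper.
        return torch.int8
```

Each plugin now owns its own dtype mapping, so the knowledge never leaves the class that defines it.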
Throughput utilities would then simply call:

```python
plugin.compute_dtype()
```
This would allow us to:
- Remove all isinstance(...) checks
- Eliminate plugin-specific imports inside throughput utilities
- Stop accessing internal precision attributes from outside the plugin
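The payoff for custom plugins can be sketched like this (again with hypothetical stand-in names): the throughput side shrinks to a single polymorphic call, and a user-defined plugin participates with no changes to throughput code:

```python
import torch

class Precision:
    """Stand-in base class (hypothetical)."""

    def compute_dtype(self) -> torch.dtype:
        return torch.float32

class MyCustomPrecision(Precision):
    """A user-defined plugin: integrates without touching throughput utilities."""

    def compute_dtype(self) -> torch.dtype:
        return torch.bfloat16

def throughput_dtype(plugin: Precision) -> torch.dtype:
    # No isinstance checks, no plugin-specific imports,
    # no access to private plugin attributes.
    return plugin.compute_dtype()
```

Unknown plugin types silently get the `torch.float32` default from the base class, which keeps the behavior backward compatible.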
The refactor would be internal and fully backward compatible since the base implementation defaults to torch.float32.
Additional context
This change aligns well with Lightning's plugin-based architecture, in which each plugin encapsulates its own behavior and metadata.
It would also:
- Reduce maintenance overhead when adding new precision backends
- Automatically support custom precision plugins in throughput tooling
- Improve separation of concerns
If this direction makes sense, I’d be happy to open a PR implementing the change across Fabric and PyTorch precision plugins with tests.
Would love to hear your thoughts.