Commit 7971d41
feat: implement MultiDtypeGenerator with FX Graph passes
- Add MultiDtypeGenerator for generating low-precision (float16/bfloat16/float8) samples
- Implement FX Graph passes for dtype conversion (no AST/string manipulation)
- Preserve BatchNorm/LayerNorm buffers as float32 for numerical stability
- Add autocast metadata for runtime mixed-precision support
- Thread-safe extraction with proper locking
- Comprehensive error handling and device detection
- Validated on CV (ResNet18) and NLP (BERT) samples
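The float32-preservation and dtype-detection points above can be illustrated at the module level. This is a minimal sketch, not the commit's FX pass: `FP32_LAYERS`, `convert_tensor`, and `to_low_precision` are hypothetical names, and the set of preserved layers is an assumption based on the BatchNorm/LayerNorm wording.

```python
import torch
import torch.nn as nn

# Assumed set of layers kept in float32 (the commit names BatchNorm and LayerNorm).
FP32_LAYERS = (nn.modules.batchnorm._BatchNorm, nn.LayerNorm)


def convert_tensor(t: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
    """Cast only floating-point tensors; integer and bool tensors pass through."""
    return t.to(dtype) if t.is_floating_point() else t


def to_low_precision(model: nn.Module, dtype: torch.dtype) -> nn.Module:
    """Cast a model to `dtype`, then restore norm layers to float32."""
    # Module.to(dtype) casts floating-point parameters and buffers only,
    # so counters such as num_batches_tracked keep their integer dtype.
    model = model.to(dtype)
    for module in model.modules():
        if isinstance(module, FP32_LAYERS):
            module.float()  # weights and running statistics back to float32
    return model
```

For example, `to_low_precision(nn.Sequential(nn.Linear(4, 4), nn.BatchNorm1d(4)), torch.bfloat16)` leaves the linear weight in bfloat16 and the BatchNorm running statistics in float32.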
Architecture:
- graph_net/torch/multi_dtype_generator.py: Main generator with RunModelDecorator
- graph_net/torch/multi_dtype_passes/: FX Graph pass implementations
  - pass_base.py: Base class for dtype conversion passes
  - dtype_conversion_pass.py: Concrete dtype conversion implementation
  - autocast_wrapper_pass.py: Autocast metadata injection
  - pass_mgr.py: Pass manager
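The split between a pass base class and a pass manager can be pictured roughly as follows. This is a sketch under assumptions: the class and method names (`PassBase`, `PassManager`, `run`, `add_pass`) are hypothetical and are not taken from the commit, and the generic type `G` stands in for the FX graph object so that the example runs without PyTorch.

```python
from abc import ABC, abstractmethod
from typing import Generic, Iterable, List, Optional, TypeVar

G = TypeVar("G")  # placeholder for the graph type (an FX GraphModule in practice)


class PassBase(ABC, Generic[G]):
    """Hypothetical base class: one transformation applied to a graph."""

    @abstractmethod
    def run(self, graph: G) -> G:
        """Return the transformed graph."""

    def __call__(self, graph: G) -> G:
        return self.run(graph)


class PassManager(Generic[G]):
    """Hypothetical manager: applies registered passes in order."""

    def __init__(self, passes: Optional[Iterable[PassBase[G]]] = None) -> None:
        self.passes: List[PassBase[G]] = list(passes or [])

    def add_pass(self, p: PassBase[G]) -> "PassManager[G]":
        self.passes.append(p)
        return self  # allow chaining

    def run(self, graph: G) -> G:
        for p in self.passes:
            graph = p(graph)
        return graph


class AppendTag(PassBase[list]):
    """Toy pass for illustration: records its tag on a list 'graph'."""

    def __init__(self, tag: str) -> None:
        self.tag = tag

    def run(self, graph: list) -> list:
        return graph + [self.tag]
```

With this shape, a dtype conversion pass and an autocast metadata pass would each subclass the base and be registered with the manager in the order they should run.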
Features:
- Real FX Graph passes (not string replacement)
- Smart weight preservation (BN/LN buffers stay float32)
- Automatic device type detection
- Compatible with test_compiler and Agent workflows
- Extensible filter system for unsupported operations
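Device detection with a CPU fallback, combined with runtime autocast, might look like the sketch below. The helper names `detect_device_type` and `run_with_autocast` are hypothetical, and the default of `torch.bfloat16` is an assumption chosen because it is supported by autocast on both CPU and CUDA.

```python
import torch


def detect_device_type() -> str:
    """Use CUDA when it is available, otherwise fall back to CPU."""
    return "cuda" if torch.cuda.is_available() else "cpu"


def run_with_autocast(model, inputs, dtype=torch.bfloat16):
    """Run `model` on `inputs` under autocast on the detected device."""
    device_type = detect_device_type()
    model = model.to(device_type)
    inputs = [x.to(device_type) for x in inputs]
    # Autocast picks the precision per operation, so the stored weights
    # do not have to be converted ahead of time.
    with torch.no_grad(), torch.autocast(device_type=device_type, dtype=dtype):
        return model(*inputs)
```

Calling `run_with_autocast(torch.nn.Linear(4, 2), [torch.randn(3, 4)])` runs on whichever device was detected, with no code change between CUDA and CPU hosts.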
Fixes:
- Correct val_map semantics in FX Graph rewriting
- Proper dtype detection (avoid false positives on int tensors)
- CPU fallback when CUDA unavailable
- Import organization and error handling
1 parent 995d7bf commit 7971d41
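The `val_map` fix listed under Fixes concerns how old nodes are mapped to new ones when a graph is rebuilt. A minimal sketch of that pattern follows, using the documented `Graph.node_copy` idiom. `cast_inputs` and the `Add` module are hypothetical illustrations, and the sketch casts every input unconditionally, so it does not reproduce the commit's int-tensor check.

```python
import torch
import torch.fx as fx


class Add(torch.nn.Module):
    def forward(self, x, y):
        return x + y


def cast_inputs(gm: fx.GraphModule, dtype: torch.dtype) -> fx.GraphModule:
    """Rebuild the graph so every input is cast to `dtype` before use."""
    new_graph = fx.Graph()
    val_map = {}  # old node -> node in new_graph that carries its value

    nodes = list(gm.graph.nodes)
    placeholders = [n for n in nodes if n.op == "placeholder"]

    # Copy all placeholders first so they stay at the top of the new graph.
    for node in placeholders:
        val_map[node] = new_graph.node_copy(node)

    # Re-point each old placeholder at its cast. Later lookups through
    # val_map then see the converted value instead of the raw input.
    for node in placeholders:
        val_map[node] = new_graph.call_method("to", (val_map[node], dtype))

    # Copy the remaining nodes, translating their arguments via val_map.
    for node in nodes:
        if node.op != "placeholder":
            val_map[node] = new_graph.node_copy(node, lambda n: val_map[n])

    new_graph.lint()
    return fx.GraphModule(gm, new_graph)
```

The key point is that `val_map` keys are nodes of the source graph and values are nodes of the new graph. Mapping an old node to the wrong new node (for example to the raw placeholder rather than its cast) silently bypasses the conversion.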
7 files changed (+562, -458 lines)
- graph_net/torch
  - multi_dtype_passes