Skip to content

Commit 7971d41

Browse files
committed
feat: implement MultiDtypeGenerator with FX Graph passes
- Add MultiDtypeGenerator for generating low-precision (float16/bfloat16/float8) samples - Implement FX Graph passes for dtype conversion (no AST/string manipulation) - Preserve BatchNorm/LayerNorm buffers as float32 for numerical stability - Add autocast metadata for runtime mixed-precision support - Thread-safe extraction with proper locking - Comprehensive error handling and device detection - Validated on CV (ResNet18) and NLP (BERT) samples Architecture: - graph_net/torch/multi_dtype_generator.py: Main generator with RunModelDecorator - graph_net/torch/multi_dtype_passes/: FX Graph pass implementations - pass_base.py: Base class for dtype conversion passes - dtype_conversion_pass.py: Concrete dtype conversion implementation - autocast_wrapper_pass.py: Autocast metadata injection - pass_mgr.py: Pass manager Features: - Real FX Graph passes (not string replacement) - Smart weight preservation (BN/LN buffers stay float32) - Automatic device type detection - Compatible with test_compiler and Agent workflows - Extensible filter system for unsupported operations Fixes: - Correct val_map semantics in FX Graph rewriting - Proper dtype detection (avoid false positives on int tensors) - CPU fallback when CUDA unavailable - Import organization and error handling
1 parent 995d7bf commit 7971d41

File tree

7 files changed

+562
-458
lines changed

7 files changed

+562
-458
lines changed

graph_net/torch/multi_dtype_generate.py

Lines changed: 0 additions & 84 deletions
This file was deleted.

0 commit comments

Comments
 (0)