
Commit b33b2d7

【Hackathon 9th Sprint No.87】feat: implement MultiDtypeGenerator for low-precision sample generation (#396)
* feat: implement MultiDtypeGenerator for low-precision sample generation
  - Add MultiDtypeGenerator class to generate float16/bfloat16/float8 samples from float32 samples
  - Automatically convert input and weight tensor dtypes
  - Keep batch_norm parameters (running_mean, running_var, scale, bias) as float32
  - Add torch.autocast context manager to model.py forward method
  - Update graph_net.json metadata with dtype information
  - Add sample validation mechanism
  - Add MultiDtypeFilter for filtering invalid graphs
  - Add command-line tool for batch generation
  - Support test_compiler evaluation and Agent code generation workflows

* refactor: optimize MultiDtypeGenerator code style and reduce nesting
  - Refactor _modify_weight_meta to use single-pass traversal
  - Extract _should_keep_weight_float32 helper method
  - Simplify _ensure_autocast_import to reduce duplicate file writes
  - Clean up unused variables in _validate_sample
  - Apply black code formatting

* feat: implement MultiDtypeGenerator with FX Graph passes
  - Add MultiDtypeGenerator for generating low-precision (float16/bfloat16/float8) samples
  - Implement FX Graph passes for dtype conversion (no AST/string manipulation)
  - Preserve BatchNorm/LayerNorm buffers as float32 for numerical stability
  - Add autocast metadata for runtime mixed-precision support
  - Thread-safe extraction with proper locking
  - Comprehensive error handling and device detection
  - Validated on CV (ResNet18) and NLP (BERT) samples

  Architecture:
  - graph_net/torch/multi_dtype_generator.py: Main generator with RunModelDecorator
  - graph_net/torch/multi_dtype_passes/: FX Graph pass implementations
    - pass_base.py: Base class for dtype conversion passes
    - dtype_conversion_pass.py: Concrete dtype conversion implementation
    - autocast_wrapper_pass.py: Autocast metadata injection
    - pass_mgr.py: Pass manager

  Features:
  - Real FX Graph passes (not string replacement)
  - Smart weight preservation (BN/LN buffers stay float32)
  - Automatic device type detection
  - Compatible with test_compiler and Agent workflows
  - Extensible filter system for unsupported operations

  Fixes:
  - Correct val_map semantics in FX Graph rewriting
  - Proper dtype detection (avoid false positives on int tensors)
  - CPU fallback when CUDA unavailable
  - Import organization and error handling

* fix: resolve ruff linter errors in multi_dtype_passes
  - Add explicit re-export in __init__.py for DtypeConversionPass
  - Use TYPE_CHECKING for forward reference in pass_mgr.py
  - Remove unused torch import in autocast_wrapper_pass.py
  - Fix type annotation to use type[DtypeConversionPass]

* feat: add data type generalization passes with FX Graph architecture
  - Add InitDataTypeGeneralizationPasses class for testing and initializing dtype conversion passes
  - Implement DtypeConversionPass base class and ConcretePass for dtype conversion
  - Add AutocastWrapperPass for handling operators not supporting low precision
  - Store applicable pass names in graph_net.json under kDataTypeGeneralizationPasses
  - Use FX Graph passes instead of AST/string manipulation for robustness
  - Add atomic file writing for graph_net.json to prevent corruption
  - Add comprehensive weight preservation logic for BatchNorm/LayerNorm parameters
  - Add RunModelPredicator integration for testing graph runnability
  - Refine exception handling with specific exception types
  - Auto-detect device type for autocast (cuda/cpu)
  - Organize passes in multi_dtype_passes/ directory with proper management

  This feature enables generating low-precision (float16/bfloat16/float8) samples from float32 computation graphs, following the initializer pattern where passes are tested offline and applied dynamically at runtime.

* Refactor dtype generalization to two-step architecture per reviewer feedback
* Fix integer tensor type detection in dtype generalization pass
* Resolve merge conflicts: merge dtype and develop branches
* Fix code formatting: apply black formatter
* Fix code formatting: add missing blank line
* Refactor: fix update_json API and add dtype generalization field constants
* Fix: create separate pass files for each dtype to match graph_net.json
* Refactor: simplify test script and extract helper functions in dtype pass
1 parent 09d7847 commit b33b2d7
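
The commit message above mentions wrapping the generated model.py forward in a torch.autocast context manager with automatic device detection. The following is only a minimal sketch of that idea on a toy module (it is not the generated code; the dtype choice and module here are illustrative):

import torch

# Illustrative sketch only (not the generated model.py): wrap the forward pass
# in torch.autocast so operators without low-precision kernels can still run.
class TinyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(8, 4)

    def forward(self, x):
        device_type = "cuda" if x.is_cuda else "cpu"  # auto-detect device, as in the commit
        with torch.autocast(device_type=device_type, dtype=torch.bfloat16):
            return self.linear(x)

out = TinyModel()(torch.randn(2, 8))
print(out.dtype)  # torch.bfloat16: linear ran in reduced precision under autocast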

File tree: 11 files changed (+878 / −5 lines)

Lines changed: 48 additions & 4 deletions
@@ -1,17 +1,61 @@
-from pathlib import Path
 import json
+from pathlib import Path
 
 kDimensionGeneralizationPasses = "dimension_generalization_passes"
+kDataTypeGeneralizationPasses = "data_type_generalization_passes"
 kSymbolicDimensionReifier = "symbolic_dimension_reifier"
 
+# Fields for dtype generalization metadata
+kDtypeGeneralizationTargetDtype = "dtype_generalization_target_dtype"
+kDtypeGeneralizationPrecision = "dtype_generalization_precision"
+kDtypeGeneralizationGenerated = "dtype_generalization_generated"
+
 
 def read_json(model_path):
+    """
+    Read JSON from graph_net.json file.
+
+    Args:
+        model_path: Path to model directory
+
+    Returns:
+        Dictionary containing JSON data
+    """
     graph_net_json_file_path = Path(f"{model_path}/graph_net.json")
     return json.loads(graph_net_json_file_path.read_text())
 
 
 def update_json(model_path, field, value):
-    graph_net_json_file_path = Path(f"{model_path}/graph_net.json")
-    graph_net_json = json.loads(graph_net_json_file_path.read_text())
+    """
+    Update a single field in graph_net.json.
+
+    Args:
+        model_path: Path to model directory or graph_net.json file
+        field: Field name to update
+        value: Value to set
+    """
+    if isinstance(model_path, (str, Path)):
+        model_path = Path(model_path)
+        # If it's a file path, use it directly; otherwise assume it's a directory
+        if model_path.suffix == ".json":
+            graph_net_json_file_path = model_path
+        else:
+            graph_net_json_file_path = model_path / "graph_net.json"
+    else:
+        graph_net_json_file_path = Path(f"{model_path}/graph_net.json")
+
+    # Read existing JSON
+    if graph_net_json_file_path.exists():
+        with open(graph_net_json_file_path, "r") as f:
+            graph_net_json = json.load(f)
+    else:
+        graph_net_json = {}
+
+    # Update field
     graph_net_json[field] = value
-    graph_net_json_file_path.write_text(json.dumps(graph_net_json, indent=4))
+
+    # Atomic write: write to temp file then rename
+    temp_path = graph_net_json_file_path.with_suffix(".json.tmp")
+    with open(temp_path, "w") as f:
+        json.dump(graph_net_json, f, indent=4)
+    temp_path.replace(graph_net_json_file_path)
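
A minimal usage sketch of the updated helpers. The module's file name is not shown in this diff, so read_json and update_json are assumed to be in scope; the temporary directory is only for illustration:

from pathlib import Path

# Assumes read_json / update_json from the module above are importable.
model_path = Path("/tmp/demo_model")
model_path.mkdir(parents=True, exist_ok=True)

# Directory form: graph_net.json is created if it does not exist yet.
update_json(model_path, "dtype_generalization_target_dtype", "float16")
# File-path form: the updated API also accepts the graph_net.json path directly.
update_json(model_path / "graph_net.json", "dtype_generalization_generated", True)

print(read_json(model_path))
# {'dtype_generalization_target_dtype': 'float16', 'dtype_generalization_generated': True}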

graph_net/test/dtype_gen_test.sh

Lines changed: 46 additions & 0 deletions
@@ -0,0 +1,46 @@
#!/bin/bash

GRAPH_NET_ROOT=$(python3 -c "import graph_net; import os; print(
os.path.dirname(graph_net.__file__))")
SAMPLES_ROOT="$GRAPH_NET_ROOT/../samples"
OUTPUT_DIR="/tmp/dtype_gen_samples"
mkdir -p "$OUTPUT_DIR"

# Step 1: Initialize dtype generalization passes
config_json_str_init=$(cat <<EOF
{
    "handler_path": "$GRAPH_NET_ROOT/torch/dtype_generalizer.py",
    "handler_class_name": "InitDataTypeGeneralizationPasses",
    "handler_config": {
        "dtype_list": ["float16", "bfloat16"],
        "model_path_prefix": "$SAMPLES_ROOT"
    }
}
EOF
)
CONFIG_INIT=$(echo "$config_json_str_init" | base64 -w 0)

python3 -m graph_net.model_path_handler --model-path "timm/resnet18" --handler-config=$CONFIG_INIT
python3 -m graph_net.model_path_handler --model-path "transformers-auto-model/opus-mt-en-gmw" --handler-config=$CONFIG_INIT

# Step 2: Apply passes to generate samples
config_json_str_apply=$(cat <<EOF
{
    "handler_path": "$GRAPH_NET_ROOT/torch/dtype_generalizer.py",
    "handler_class_name": "ApplyDataTypeGeneralizationPasses",
    "handler_config": {
        "output_dir": "$OUTPUT_DIR",
        "model_path_prefix": "$SAMPLES_ROOT",
        "model_runnable_predicator_filepath": "$GRAPH_NET_ROOT/torch/constraint_util.py",
        "model_runnable_predicator_class_name": "RunModelPredicator",
        "model_runnable_predicator_config": {
            "use_dummy_inputs": true
        }
    }
}
EOF
)
CONFIG_APPLY=$(echo "$config_json_str_apply" | base64 -w 0)

python3 -m graph_net.model_path_handler --model-path "timm/resnet18" --handler-config=$CONFIG_APPLY
python3 -m graph_net.model_path_handler --model-path "transformers-auto-model/opus-mt-en-gmw" --handler-config=$CONFIG_APPLY
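
The script base64-encodes each handler config before passing it on the command line. As a convenience sketch (not part of the commit; the paths below are placeholders), the same kind of encoded string can be produced from Python:

import base64
import json

# Convenience sketch, not part of the commit: build a base64-encoded handler
# config like the one the test script passes via --handler-config.
config = {
    "handler_path": "graph_net/torch/dtype_generalizer.py",  # placeholder path
    "handler_class_name": "InitDataTypeGeneralizationPasses",
    "handler_config": {
        "dtype_list": ["float16", "bfloat16"],
        "model_path_prefix": "samples",  # placeholder path
    },
}
encoded = base64.b64encode(json.dumps(config).encode()).decode()
print(encoded)  # value to pass as --handler-config=<encoded>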
Lines changed: 3 additions & 0 deletions
@@ -0,0 +1,3 @@
from graph_net.torch.dtype_gen_passes.pass_base import (
    DtypeGeneralizationPass as DtypeGeneralizationPass,
)
Lines changed: 144 additions & 0 deletions
@@ -0,0 +1,144 @@
"""
Concrete implementation of dtype generalization pass.

This pass converts tensor dtypes in FX Graph by:
1. Converting placeholder nodes (inputs) to target dtype
2. Converting get_attr nodes (weights) to target dtype, except preserved weights
3. Inserting .to(dtype) calls where needed
"""

import torch
import torch.fx as fx
from graph_net.torch.dtype_gen_passes.pass_base import DtypeGeneralizationPass


class ConcretePass(DtypeGeneralizationPass):
    """
    FX Graph pass that converts dtypes of tensors.

    This pass modifies the graph to:
    - Convert input tensors to target dtype
    - Convert weight tensors to target dtype (except preserved weights)
    - Insert dtype conversion nodes where necessary
    """

    def get_pass_name(self) -> str:
        return f"dtype_generalization_{self.target_dtype}"

    def need_rewrite(self, gm: fx.GraphModule) -> bool:
        """
        Check if graph has float32 tensors that need conversion.
        """
        for node in gm.graph.nodes:
            if self._node_need_rewrite(node):
                return True
        return False

    def _node_need_rewrite(self, node: fx.Node) -> bool:
        """
        Check if a specific node needs dtype conversion.

        Args:
            node: FX Node to check

        Returns:
            True if node should be rewritten
        """
        # Check placeholder nodes (inputs)
        if node.op == "placeholder":
            return self._is_float32_tensor(node)

        # Check get_attr nodes (weights)
        if node.op == "get_attr":
            if self._is_float32_tensor(node):
                # Only rewrite if not in preserve list
                attr_name = str(node.target)
                return not self.should_preserve_weight(attr_name)

        return False

    def rewrite(self, gm: fx.GraphModule) -> fx.GraphModule:
        """
        Rewrite the graph to convert dtypes.

        Strategy:
        1. For each placeholder (input), insert .to(target_dtype) after it
        2. For each get_attr (weight), insert .to(target_dtype) if not preserved
        3. Update the graph and recompile
        """
        new_graph = fx.Graph()
        val_map = {}

        def create_placeholder(node: fx.Node) -> fx.Node:
            """Create a placeholder node with dtype conversion if needed."""
            new_node = new_graph.node_copy(node, lambda x: val_map.get(x, x))
            if self._is_float32_tensor(node):
                return new_graph.call_method("to", args=(new_node, self.torch_dtype))
            return new_node

        def create_get_attr(node: fx.Node) -> fx.Node:
            """Create a get_attr node with dtype conversion if needed."""
            new_node = new_graph.node_copy(node, lambda x: val_map.get(x, x))
            attr_name = str(node.target)
            if self._is_float32_tensor(node) and not self.should_preserve_weight(
                attr_name
            ):
                return new_graph.call_method("to", args=(new_node, self.torch_dtype))
            return new_node

        for node in gm.graph.nodes:
            if node.op == "placeholder":
                val_map[node] = create_placeholder(node)
            elif node.op == "get_attr":
                val_map[node] = create_get_attr(node)
            else:
                new_node = new_graph.node_copy(node, lambda x: val_map.get(x, x))
                val_map[node] = new_node

        # Replace the graph
        gm.graph = new_graph
        gm.recompile()

        return gm

    def _is_float32_tensor(self, node: fx.Node) -> bool:
        """
        Check if a node represents a float32 tensor.

        Args:
            node: FX Node to check

        Returns:
            True if node is a float32 tensor
        """
        # Check tensor_meta if available (most reliable)
        if "tensor_meta" in node.meta:
            tensor_meta = node.meta["tensor_meta"]
            if hasattr(tensor_meta, "dtype"):
                return tensor_meta.dtype == torch.float32

        # For placeholder and get_attr nodes without metadata,
        # we need to be conservative and only return True if explicitly float
        if node.op in ("placeholder", "get_attr"):
            # Check type annotation if available
            if node.type is not None:
                type_str = str(node.type).lower()

                # Explicitly check for integer types - these should NOT be converted
                integer_types = ["long", "int", "short", "byte", "bool"]
                if any(int_type in type_str for int_type in integer_types):
                    return False

                # Only return True if explicitly a floating point tensor
                # Check for explicit float types: FloatTensor, float32, float16, etc.
                float_indicators = ["float", "double", "half", "bfloat"]
                if any(
                    float_indicator in type_str for float_indicator in float_indicators
                ):
                    return True

            # For generic "Tensor" without explicit dtype, be conservative
            # Don't assume it's float32 - it might be integer
            return False

        return False
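
pass_base.py is not part of this diff, so the constructor and the torch_dtype / should_preserve_weight members used above are only known indirectly. With that caveat, a rough usage sketch of the pass on a traced module (constructor arguments inferred from the float16/bfloat16 subclasses below):

import torch
import torch.fx as fx
from torch.fx.passes.shape_prop import ShapeProp

# Rough usage sketch. The ConcretePass constructor arguments are inferred from
# the per-dtype subclasses below; the real pass_base.py signature may differ.
class SmallNet(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = torch.nn.Linear(4, 2)

    def forward(self, x):
        return torch.relu(self.fc(x))

gm = fx.symbolic_trace(SmallNet())
ShapeProp(gm).propagate(torch.randn(1, 4))  # populates node.meta["tensor_meta"]

dtype_pass = ConcretePass(target_dtype="float16", preserve_weights=set())
if dtype_pass.need_rewrite(gm):
    gm = dtype_pass.rewrite(gm)
print(gm.code)  # the float32 placeholder now flows through an inserted .to(...) call

Note that in this toy trace the weights are reached through call_module nodes, so only the input placeholder is converted here; GraphNet samples expose weights as get_attr nodes, which the pass handles as well.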
Lines changed: 35 additions & 0 deletions
@@ -0,0 +1,35 @@
"""
Dtype generalization pass for bfloat16.

This pass converts float32 tensors to bfloat16.
"""

from graph_net.torch.dtype_gen_passes.dtype_generalization_pass import (
    ConcretePass as BaseConcretePass,
)

# Weights that must remain float32 for numerical stability
FLOAT32_PRESERVED_WEIGHTS = {
    "running_mean",
    "running_var",
    "num_batches_tracked",
    "bn_parameters_weight",
    "bn_parameters_bias",
    "ln_parameters_weight",
    "ln_parameters_bias",
}


class ConcretePass(BaseConcretePass):
    """
    FX Graph pass that converts dtypes to bfloat16.
    """

    def __init__(self, *args, **kwargs):
        # Override target_dtype to bfloat16
        super().__init__(
            target_dtype="bfloat16",
            preserve_weights=FLOAT32_PRESERVED_WEIGHTS,
            *args,
            **kwargs,
        )
Lines changed: 35 additions & 0 deletions
@@ -0,0 +1,35 @@
"""
Dtype generalization pass for float16.

This pass converts float32 tensors to float16.
"""

from graph_net.torch.dtype_gen_passes.dtype_generalization_pass import (
    ConcretePass as BaseConcretePass,
)

# Weights that must remain float32 for numerical stability
FLOAT32_PRESERVED_WEIGHTS = {
    "running_mean",
    "running_var",
    "num_batches_tracked",
    "bn_parameters_weight",
    "bn_parameters_bias",
    "ln_parameters_weight",
    "ln_parameters_bias",
}


class ConcretePass(BaseConcretePass):
    """
    FX Graph pass that converts dtypes to float16.
    """

    def __init__(self, *args, **kwargs):
        # Override target_dtype to float16
        super().__init__(
            target_dtype="float16",
            preserve_weights=FLOAT32_PRESERVED_WEIGHTS,
            *args,
            **kwargs,
        )
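
Not from the commit, but a quick illustration of why samples are generated for both float16 and bfloat16: the two formats trade off precision against range, so a graph can fail under one dtype and not the other:

import torch

# Quick illustration (not from the commit): float16 has a small exponent range,
# while bfloat16 keeps the float32 range but with a much coarser mantissa.
x = torch.tensor([1e-4, 7e4], dtype=torch.float32)
print(x.to(torch.float16))   # 7e4 overflows to inf (float16 max is about 65504)
print(x.to(torch.bfloat16))  # both values representable, but rounded coarsely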
