
Commit cafe77c

Fix: create separate pass files for each dtype to match graph_net.json
1 parent 90afcc9 commit cafe77c

3 files changed: +87 −31 lines changed

Lines changed: 35 additions & 0 deletions
@@ -0,0 +1,35 @@
+"""
+Dtype generalization pass for bfloat16.
+
+This pass converts float32 tensors to bfloat16.
+"""
+
+from graph_net.torch.dtype_gen_passes.dtype_generalization_pass import (
+    ConcretePass as BaseConcretePass,
+)
+
+# Weights that must remain float32 for numerical stability
+FLOAT32_PRESERVED_WEIGHTS = {
+    "running_mean",
+    "running_var",
+    "num_batches_tracked",
+    "bn_parameters_weight",
+    "bn_parameters_bias",
+    "ln_parameters_weight",
+    "ln_parameters_bias",
+}
+
+
+class ConcretePass(BaseConcretePass):
+    """
+    FX Graph pass that converts dtypes to bfloat16.
+    """
+
+    def __init__(self, *args, **kwargs):
+        # Override target_dtype to bfloat16
+        super().__init__(
+            target_dtype="bfloat16",
+            preserve_weights=FLOAT32_PRESERVED_WEIGHTS,
+            *args,
+            **kwargs,
+        )
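For context only, a minimal usage sketch of the new bfloat16 pass, assuming it lives next to the base pass under graph_net.torch.dtype_gen_passes (the new file's path is not shown in this diff) and assuming the need_rewrite/rewrite interface that dtype_generalizer.py drives below; the toy module and tracing step are purely illustrative, not part of this commit.

# Hypothetical usage sketch, not part of this commit.
import torch
from torch.fx import symbolic_trace

# Assumed module path for the new per-dtype file shown above.
from graph_net.torch.dtype_gen_passes.dtype_generalization_pass_bfloat16 import (
    ConcretePass,
)


class TinyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 4)  # float32 parameters by default

    def forward(self, x):
        return self.linear(x)


traced = symbolic_trace(TinyModel())
dtype_pass = ConcretePass()  # target_dtype="bfloat16" is baked into the subclass
if dtype_pass.need_rewrite(traced):  # same check dtype_generalizer.py performs
    traced = dtype_pass.rewrite(traced)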
Lines changed: 35 additions & 0 deletions
@@ -0,0 +1,35 @@
+"""
+Dtype generalization pass for float16.
+
+This pass converts float32 tensors to float16.
+"""
+
+from graph_net.torch.dtype_gen_passes.dtype_generalization_pass import (
+    ConcretePass as BaseConcretePass,
+)
+
+# Weights that must remain float32 for numerical stability
+FLOAT32_PRESERVED_WEIGHTS = {
+    "running_mean",
+    "running_var",
+    "num_batches_tracked",
+    "bn_parameters_weight",
+    "bn_parameters_bias",
+    "ln_parameters_weight",
+    "ln_parameters_bias",
+}
+
+
+class ConcretePass(BaseConcretePass):
+    """
+    FX Graph pass that converts dtypes to float16.
+    """
+
+    def __init__(self, *args, **kwargs):
+        # Override target_dtype to float16
+        super().__init__(
+            target_dtype="float16",
+            preserve_weights=FLOAT32_PRESERVED_WEIGHTS,
+            *args,
+            **kwargs,
+        )

graph_net/torch/dtype_generalizer.py

Lines changed: 17 additions & 31 deletions
@@ -110,20 +110,18 @@ def _test_dtype_passes(
             traced_model: Traced FX GraphModule
 
         Returns:
-            List of pass names that work
+            List of pass names that work (pass file names without .py extension)
         """
         working_passes = []
 
         for dtype in self.dtype_list:
-            pass_name = "dtype_generalization_pass"  # Full pass file name without .py
+            # Pass name directly corresponds to file name (without .py)
+            pass_name = f"dtype_generalization_pass_{dtype}"
 
             try:
-                # Try to apply the pass
+                # Try to load and apply the pass
                 dtype_pass_class = get_dtype_generalization_pass(pass_name)
-                dtype_pass = dtype_pass_class(
-                    target_dtype=dtype,
-                    preserve_weights=FLOAT32_PRESERVED_WEIGHTS,
-                )
+                dtype_pass = dtype_pass_class()
 
                 # Check if pass is needed
                 if not dtype_pass.need_rewrite(traced_model):
@@ -135,11 +133,11 @@ def _test_dtype_passes(
 
                 # Try to run the modified graph
                 if self._test_graph_runnable(model_path, gm_copy, dtype):
-                    working_passes.append(f"{pass_name}_{dtype}")
-                    logging.info(f"Pass {pass_name}_{dtype} works for {model_path}")
+                    working_passes.append(pass_name)
+                    logging.info(f"Pass {pass_name} works for {model_path}")
 
             except (RuntimeError, ValueError, TypeError) as e:
-                logging.warning(f"Pass {pass_name}_{dtype} failed: {e}")
+                logging.warning(f"Pass {pass_name} failed: {e}")
                 continue
 
         return working_passes
@@ -302,37 +300,25 @@ def _apply_pass_and_generate(
         Args:
             model_path: Original model path
             traced_model: Original traced model
-            pass_name: Name of the pass to apply (e.g., "dtype_generalization_pass_float16")
+            pass_name: Name of the pass file (without .py extension),
+                e.g., "dtype_generalization_pass_float16"
 
         Returns:
             Path to the generated sample directory
         """
-        # Parse pass name to extract base name and dtype
-        # Format: "dtype_generalization_pass_float16" or "dtype_generalization_pass_bfloat16"
-        # The base name "dtype_generalization_pass" corresponds to the file
-        # dtype_generalization_pass.py, which contains the ConcretePass class.
-        parts = pass_name.rsplit("_", 1)
-        if len(parts) != 2:
+        # Pass name directly corresponds to file name (without .py)
+        # Extract dtype from pass name for output directory naming
+        if not pass_name.startswith("dtype_generalization_pass_"):
             raise ValueError(
-                f"Invalid pass name format: {pass_name}. "
+                f"Invalid pass name: {pass_name}. "
                 f"Expected format: 'dtype_generalization_pass_<dtype>'"
             )
 
-        base_name, dtype = parts
-
-        # Validate base name
-        if base_name != "dtype_generalization_pass":
-            raise ValueError(
-                f"Unknown pass base name: {base_name}. "
-                f"Expected: 'dtype_generalization_pass'"
-            )
+        dtype = pass_name.replace("dtype_generalization_pass_", "")
 
         # Load and apply the pass
-        dtype_pass_class = get_dtype_generalization_pass(base_name)
-        dtype_pass = dtype_pass_class(
-            target_dtype=dtype,
-            preserve_weights=FLOAT32_PRESERVED_WEIGHTS,
-        )
+        dtype_pass_class = get_dtype_generalization_pass(pass_name)
+        dtype_pass = dtype_pass_class()
 
         gm_copy = copy.deepcopy(traced_model)
         gm_modified = dtype_pass.rewrite(gm_copy)
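The loader itself is not part of this diff; as an illustration of why per-dtype pass files are needed, here is one plausible sketch of how get_dtype_generalization_pass could resolve a pass file name (without .py) to its ConcretePass class. The importlib-based lookup and the package path are assumptions, not the repository's actual implementation.

# Assumed sketch of the loader; the real get_dtype_generalization_pass is not
# shown in this commit.
import importlib


def get_dtype_generalization_pass(pass_name: str):
    # e.g. "dtype_generalization_pass_float16" resolves to the module
    # graph_net.torch.dtype_gen_passes.dtype_generalization_pass_float16,
    # which must define a ConcretePass class (as the new files above do).
    module = importlib.import_module(
        f"graph_net.torch.dtype_gen_passes.{pass_name}"
    )
    return module.ConcretePass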
