Skip to content

Commit 7da86c6

Browse files
committed
Fix integer tensor type detection in dtype generalization pass
1 parent 495241d commit 7da86c6

File tree

5 files changed

+136
-95
lines changed

5 files changed

+136
-95
lines changed
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,33 @@
11
kDimensionGeneralizationPasses = "dimension_generalization_passes"
22
kDataTypeGeneralizationPasses = "data_type_generalization_passes"
3+
4+
import json
5+
from pathlib import Path
6+
from typing import Union
7+
8+
9+
def update_json(json_path: Union[str, Path], updates: dict) -> None:
    """
    Merge ``updates`` into a JSON file, replacing the file atomically.

    The current top-level object is loaded (an empty object is used when the
    file does not exist yet), ``updates`` is applied on top of it with
    ``dict.update`` semantics, and the result is written to a sibling
    temporary file that is then renamed over the target. Readers therefore
    never observe a half-written file. Note that the read-modify-write cycle
    as a whole is not protected against concurrent writers.

    Args:
        json_path: Path to the JSON file
        updates: Dictionary of key-value pairs to update

    Raises:
        TypeError: If the existing file's top-level JSON value is not an
            object, or if a value in ``updates`` is not JSON-serializable.
    """
    json_path = Path(json_path)

    # Read existing JSON
    if json_path.exists():
        with open(json_path, "r", encoding="utf-8") as f:
            metadata = json.load(f)
        if not isinstance(metadata, dict):
            raise TypeError(
                f"Expected a JSON object at the top level of {json_path}, "
                f"got {type(metadata).__name__}"
            )
    else:
        metadata = {}

    # Apply updates
    metadata.update(updates)

    # Atomic write: write to a temp file next to the target, then rename.
    # Appending ".tmp" to the full file name (rather than swapping the
    # suffix) keeps the temp name tied to this exact target file.
    temp_path = json_path.with_name(json_path.name + ".tmp")
    try:
        with open(temp_path, "w", encoding="utf-8") as f:
            json.dump(metadata, f, indent=4)
        temp_path.replace(json_path)
    except BaseException:
        # Do not leave a partially written temp file behind on failure;
        # the original target file is untouched at this point.
        try:
            temp_path.unlink()
        except OSError:
            pass
        raise

graph_net/test/dtype_gen_test.sh

Lines changed: 27 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -31,27 +31,28 @@ echo ""
3131
echo "[1/2] Testing CV model: timm/resnet18"
3232
config_json_str_init=$(cat <<EOF
3333
{
34-
"decorator_path": "$GRAPH_NET_ROOT/torch/dtype_generalizer.py",
35-
"decorator_class_name": "InitDataTypeGeneralizationPasses",
36-
"decorator_config": {
37-
"dtype_list": ["float16", "bfloat16"]
34+
"handler_path": "$GRAPH_NET_ROOT/torch/dtype_generalizer.py",
35+
"handler_class_name": "InitDataTypeGeneralizationPasses",
36+
"handler_config": {
37+
"dtype_list": ["float16", "bfloat16"],
38+
"model_path_prefix": "$SAMPLES_ROOT"
3839
}
3940
}
4041
EOF
4142
)
4243
CONFIG_INIT=$(echo "$config_json_str_init" | base64 -w 0)
4344

44-
python3 -m graph_net.torch.run_model \
45-
--model-path "$SAMPLES_ROOT/timm/resnet18" \
46-
--decorator-config="$CONFIG_INIT" || echo "Warning: CV model test failed"
45+
python3 -m graph_net.model_path_handler \
46+
--model-path "timm/resnet18" \
47+
--handler-config="$CONFIG_INIT" || echo "Warning: CV model test failed"
4748

4849
echo ""
4950

5051
# Test on an NLP model (BERT-like)
5152
echo "[2/2] Testing NLP model: transformers-auto-model/opus-mt-en-gmw"
52-
python3 -m graph_net.torch.run_model \
53-
--model-path "$SAMPLES_ROOT/transformers-auto-model/opus-mt-en-gmw" \
54-
--decorator-config="$CONFIG_INIT" || echo "Warning: NLP model test failed"
53+
python3 -m graph_net.model_path_handler \
54+
--model-path "transformers-auto-model/opus-mt-en-gmw" \
55+
--handler-config="$CONFIG_INIT" || echo "Warning: NLP model test failed"
5556

5657
echo ""
5758
echo "Step 1 completed. Pass names written to graph_net.json"
@@ -67,69 +68,37 @@ echo ""
6768

6869
config_json_str_apply=$(cat <<EOF
6970
{
70-
"decorator_path": "$GRAPH_NET_ROOT/torch/dtype_generalizer.py",
71-
"decorator_class_name": "ApplyDataTypeGeneralizationPasses",
72-
"decorator_config": {
73-
"output_dir": "$OUTPUT_DIR"
71+
"handler_path": "$GRAPH_NET_ROOT/torch/dtype_generalizer.py",
72+
"handler_class_name": "ApplyDataTypeGeneralizationPasses",
73+
"handler_config": {
74+
"output_dir": "$OUTPUT_DIR",
75+
"model_path_prefix": "$SAMPLES_ROOT",
76+
"model_runnable_predicator_filepath": "$GRAPH_NET_ROOT/torch/constraint_util.py",
77+
"model_runnable_predicator_class_name": "RunModelPredicator",
78+
"model_runnable_predicator_config": {
79+
"use_dummy_inputs": true
80+
}
7481
}
7582
}
7683
EOF
7784
)
7885
CONFIG_APPLY=$(echo "$config_json_str_apply" | base64 -w 0)
7986

8087
echo "[1/2] Generating CV samples..."
81-
python3 -m graph_net.torch.run_model \
82-
--model-path "$SAMPLES_ROOT/timm/resnet18" \
83-
--decorator-config="$CONFIG_APPLY" || echo "Warning: CV generation failed"
88+
python3 -m graph_net.model_path_handler \
89+
--model-path "timm/resnet18" \
90+
--handler-config="$CONFIG_APPLY" || echo "Warning: CV generation failed"
8491

8592
echo ""
8693

8794
echo "[2/2] Generating NLP samples..."
88-
python3 -m graph_net.torch.run_model \
89-
--model-path "$SAMPLES_ROOT/transformers-auto-model/opus-mt-en-gmw" \
90-
--decorator-config="$CONFIG_APPLY" || echo "Warning: NLP generation failed"
95+
python3 -m graph_net.model_path_handler \
96+
--model-path "transformers-auto-model/opus-mt-en-gmw" \
97+
--handler-config="$CONFIG_APPLY" || echo "Warning: NLP generation failed"
9198

9299
echo ""
93100
echo "Step 2 completed. Generated samples in: $OUTPUT_DIR"
94-
echo ""
95-
96-
# ============================================
97-
# Verification
98-
# ============================================
99-
echo "=========================================="
100-
echo "Verification"
101-
echo "=========================================="
102-
echo ""
103-
104-
if [ -d "$OUTPUT_DIR" ]; then
105-
echo "Generated samples:"
106-
ls -lh "$OUTPUT_DIR"
107-
echo ""
108-
109-
# Count generated samples
110-
SAMPLE_COUNT=$(find "$OUTPUT_DIR" -mindepth 1 -maxdepth 1 -type d | wc -l)
111-
echo "Total samples generated: $SAMPLE_COUNT"
112-
113-
if [ $SAMPLE_COUNT -gt 0 ]; then
114-
echo ""
115-
echo "✓ Test PASSED: Successfully generated $SAMPLE_COUNT low-precision samples"
116-
echo ""
117-
echo "You can now use these samples for:"
118-
echo " - test_compiler evaluation"
119-
echo " - Agent code generation"
120-
echo " - Performance benchmarking"
121-
else
122-
echo ""
123-
echo "✗ Test WARNING: No samples were generated"
124-
echo " This might be normal if models don't support dtype conversion"
125-
fi
126-
else
127-
echo "✗ Test FAILED: Output directory not created"
128-
exit 1
129-
fi
130-
131101
echo ""
132102
echo "=========================================="
133103
echo "Test Complete"
134104
echo "=========================================="
135-

graph_net/torch/dtype_gen_passes/dtype_generalization_pass.py

Lines changed: 16 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -129,16 +129,25 @@ def _is_float32_tensor(self, node: fx.Node) -> bool:
129129
return tensor_meta.dtype == torch.float32
130130

131131
# For placeholder and get_attr nodes without metadata,
132-
# conservatively assume they might be float32
133-
# This is safe because:
134-
# 1. .to() on non-float tensors is a no-op for most cases
135-
# 2. Integer tensors (like input_ids) won't be affected
132+
# we need to be conservative and only return True if explicitly float
136133
if node.op in ("placeholder", "get_attr"):
137134
# Check type annotation if available
138135
if node.type is not None:
139-
type_str = str(node.type)
140-
# Only return True if it's explicitly a floating point tensor
141-
if "Tensor" in type_str and "int" not in type_str.lower():
136+
type_str = str(node.type).lower()
137+
138+
# Explicitly check for integer types - these should NOT be converted
139+
integer_types = ["long", "int", "short", "byte", "bool"]
140+
if any(int_type in type_str for int_type in integer_types):
141+
return False
142+
143+
# Only return True if explicitly a floating point tensor
144+
# Check for explicit float types: FloatTensor, float32, float16, etc.
145+
float_indicators = ["float", "double", "half", "bfloat"]
146+
if any(float_indicator in type_str for float_indicator in float_indicators):
142147
return True
148+
149+
# For generic "Tensor" without explicit dtype, be conservative
150+
# Don't assume it's float32 - it might be integer
151+
return False
143152

144153
return False

graph_net/torch/dtype_generalizer.py

Lines changed: 57 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -19,14 +19,18 @@
1919

2020
import torch.fx as fx
2121

22-
from graph_net.graph_net_json_file_util import kDataTypeGeneralizationPasses
22+
from graph_net.graph_net_json_file_util import (
23+
kDataTypeGeneralizationPasses,
24+
update_json,
25+
)
2326
from graph_net.torch.constraint_util import RunModelPredicator
2427
from graph_net.torch.fx_graph_cache_util import (
2528
parse_immutable_model_path_into_sole_graph_module,
2629
)
2730
from graph_net.torch.fx_graph_serialize_util import serialize_graph_module_to_str
2831
from graph_net.torch.dtype_gen_passes.pass_mgr import get_dtype_generalization_pass
2932
from graph_net.torch import utils
33+
from graph_net.imp_util import load_module
3034

3135

3236
# Weights that must remain float32 for numerical stability
@@ -51,12 +55,14 @@ class InitDataTypeGeneralizationPasses:
5155
Config format:
5256
{
5357
"dtype_list": ["float16", "bfloat16"],
58+
"model_path_prefix": "",
5459
}
5560
"""
5661

5762
def __init__(self, config: Dict[str, Any]):
5863
self.config = config
5964
self.dtype_list = config.get("dtype_list", ["float16", "bfloat16"])
65+
self.model_path_prefix = config.get("model_path_prefix", "")
6066

6167
# Validate dtypes
6268
valid_dtypes = {"float16", "bfloat16", "float8"}
@@ -71,8 +77,12 @@ def __call__(self, model_path: str) -> None:
7177
Initialize dtype passes for the given model.
7278
7379
Args:
74-
model_path: Path to the model directory
80+
model_path: Path to the model directory (may be relative to model_path_prefix)
7581
"""
82+
# Apply model_path_prefix if provided
83+
if self.model_path_prefix:
84+
model_path = str(Path(self.model_path_prefix) / model_path)
85+
7686
# Parse the computation graph
7787
traced_model = parse_immutable_model_path_into_sole_graph_module(model_path)
7888

@@ -176,18 +186,7 @@ def _save_dtype_pass_names(
176186
model_path: Path to model directory
177187
"""
178188
graph_net_json_path = Path(model_path) / "graph_net.json"
179-
180-
with open(graph_net_json_path, "r") as f:
181-
metadata = json.load(f)
182-
183-
metadata[kDataTypeGeneralizationPasses] = dtype_pass_names
184-
185-
# Atomic write: write to temp file then rename
186-
temp_path = graph_net_json_path.with_suffix(".json.tmp")
187-
with open(temp_path, "w") as f:
188-
json.dump(metadata, f, indent=4)
189-
190-
temp_path.replace(graph_net_json_path)
189+
update_json(graph_net_json_path, {kDataTypeGeneralizationPasses: dtype_pass_names})
191190

192191

193192
class ApplyDataTypeGeneralizationPasses:
@@ -200,6 +199,10 @@ class ApplyDataTypeGeneralizationPasses:
200199
Config format:
201200
{
202201
"output_dir": "/path/to/output",
202+
"model_path_prefix": "",
203+
"model_runnable_predicator_filepath": "...",
204+
"model_runnable_predicator_class_name": "...",
205+
"model_runnable_predicator_config": {...},
203206
}
204207
"""
205208

@@ -208,17 +211,41 @@ def __init__(self, config: Dict[str, Any]):
208211
self.output_dir = config.get("output_dir")
209212
if not self.output_dir:
210213
raise ValueError("output_dir is required in config")
214+
215+
self.model_path_prefix = config.get("model_path_prefix", "")
216+
217+
# model_runnable_predicator is required to ensure generated code is runnable
218+
if "model_runnable_predicator_filepath" not in config:
219+
raise ValueError(
220+
"model_runnable_predicator_filepath is required in config. "
221+
"Generated code must be validated."
222+
)
223+
self.model_runnable_predicator = self._make_model_runnable_predicator(config)
224+
225+
def _make_model_runnable_predicator(self, config: Dict[str, Any]):
226+
"""Create model runnable predicator from config."""
227+
module = load_module(config["model_runnable_predicator_filepath"])
228+
cls = getattr(
229+
module,
230+
config.get("model_runnable_predicator_class_name", "RunModelPredicator"),
231+
)
232+
predicator_config = config.get("model_runnable_predicator_config", {})
233+
return cls(predicator_config)
211234

212235
def __call__(self, model_path: str) -> List[str]:
213236
"""
214237
Apply dtype passes to generate new samples.
215238
216239
Args:
217-
model_path: Path to the original model directory
240+
model_path: Path to the original model directory (may be relative to model_path_prefix)
218241
219242
Returns:
220243
List of generated sample directories
221244
"""
245+
# Apply model_path_prefix if provided
246+
if self.model_path_prefix:
247+
model_path = str(Path(self.model_path_prefix) / model_path)
248+
222249
# Read pass names from graph_net.json
223250
dtype_pass_names = self._read_dtype_pass_names(model_path)
224251

@@ -316,6 +343,13 @@ def _apply_pass_and_generate(
316343
# Update graph_net.json with dtype information
317344
self._update_sample_metadata(output_sample_dir, dtype)
318345

346+
# Validate generated sample (required - generated code must be runnable)
347+
if not self.model_runnable_predicator(str(output_sample_dir)):
348+
raise RuntimeError(
349+
f"Generated sample failed validation: {output_sample_dir}"
350+
)
351+
logging.info(f"Generated sample validated: {output_sample_dir}")
352+
319353
return str(output_sample_dir)
320354

321355
def _update_sample_metadata(self, sample_dir: Path, dtype: str) -> None:
@@ -327,20 +361,14 @@ def _update_sample_metadata(self, sample_dir: Path, dtype: str) -> None:
327361
dtype: Target dtype
328362
"""
329363
graph_net_json_path = sample_dir / "graph_net.json"
330-
331-
with open(graph_net_json_path, "r") as f:
332-
metadata = json.load(f)
333-
334-
# Add dtype information
335-
metadata["dtype"] = dtype
336-
metadata["precision"] = dtype
337-
metadata["generated_from_dtype_generalization"] = True
338-
339-
# Atomic write
340-
temp_path = graph_net_json_path.with_suffix(".json.tmp")
341-
with open(temp_path, "w") as f:
342-
json.dump(metadata, f, indent=4)
343-
temp_path.replace(graph_net_json_path)
364+
update_json(
365+
graph_net_json_path,
366+
{
367+
"dtype": dtype,
368+
"precision": dtype,
369+
"generated_from_dtype_generalization": True,
370+
},
371+
)
344372

345373

346374
class MultiDtypeFilter:

samples/timm/resnet18/graph_net.json

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,5 +3,9 @@
33
"num_devices_required": 1,
44
"num_nodes_required": 1,
55
"source": "timm",
6-
"heuristic_tag": "computer_vision"
6+
"heuristic_tag": "computer_vision",
7+
"data_type_generalization_passes": [
8+
"dtype_generalization_pass_float16",
9+
"dtype_generalization_pass_bfloat16"
10+
]
711
}

0 commit comments

Comments
 (0)