
Commit 36e69a6

Merge branch 'develop' of github.com:PaddlePaddle/GraphNet into device_rewrite_sample_pass

2 parents 9069e4d + b33b2d7

File tree: 11 files changed (+878, −5 lines)
Lines changed: 48 additions & 4 deletions
```diff
@@ -1,17 +1,61 @@
-from pathlib import Path
 import json
+from pathlib import Path
 
 kDimensionGeneralizationPasses = "dimension_generalization_passes"
+kDataTypeGeneralizationPasses = "data_type_generalization_passes"
 kSymbolicDimensionReifier = "symbolic_dimension_reifier"
 
+# Fields for dtype generalization metadata
+kDtypeGeneralizationTargetDtype = "dtype_generalization_target_dtype"
+kDtypeGeneralizationPrecision = "dtype_generalization_precision"
+kDtypeGeneralizationGenerated = "dtype_generalization_generated"
+
 
 def read_json(model_path):
+    """
+    Read JSON from graph_net.json file.
+
+    Args:
+        model_path: Path to model directory
+
+    Returns:
+        Dictionary containing JSON data
+    """
     graph_net_json_file_path = Path(f"{model_path}/graph_net.json")
     return json.loads(graph_net_json_file_path.read_text())
 
 
 def update_json(model_path, field, value):
-    graph_net_json_file_path = Path(f"{model_path}/graph_net.json")
-    graph_net_json = json.loads(graph_net_json_file_path.read_text())
+    """
+    Update a single field in graph_net.json.
+
+    Args:
+        model_path: Path to model directory or graph_net.json file
+        field: Field name to update
+        value: Value to set
+    """
+    if isinstance(model_path, (str, Path)):
+        model_path = Path(model_path)
+        # If it's a file path, use it directly; otherwise assume it's a directory
+        if model_path.suffix == ".json":
+            graph_net_json_file_path = model_path
+        else:
+            graph_net_json_file_path = model_path / "graph_net.json"
+    else:
+        graph_net_json_file_path = Path(f"{model_path}/graph_net.json")
+
+    # Read existing JSON
+    if graph_net_json_file_path.exists():
+        with open(graph_net_json_file_path, "r") as f:
+            graph_net_json = json.load(f)
+    else:
+        graph_net_json = {}
+
+    # Update field
     graph_net_json[field] = value
-    graph_net_json_file_path.write_text(json.dumps(graph_net_json, indent=4))
+
+    # Atomic write: write to temp file then rename
+    temp_path = graph_net_json_file_path.with_suffix(".json.tmp")
+    with open(temp_path, "w") as f:
+        json.dump(graph_net_json, f, indent=4)
+    temp_path.replace(graph_net_json_file_path)
```
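For reference, the rewritten `update_json` accepts either a model directory or a direct path to `graph_net.json`, creates the file when it does not exist yet, and swaps the plain `write_text` for a write-to-temp-then-rename sequence so an interrupted write cannot leave a truncated JSON file behind. A minimal usage sketch; the import path `graph_net.utils` is an assumption, since this hunk's file name is not shown in the diff:

```python
# Hypothetical usage of the helpers above; the module path is assumed,
# because this hunk's file name is not shown in the diff.
from graph_net.utils import read_json, update_json

model_dir = "/tmp/samples/timm/resnet18"  # made-up sample directory

# Works even when graph_net.json does not exist yet: it starts from {}.
update_json(model_dir, "dtype_generalization_target_dtype", "bfloat16")

# A direct file path (suffix ".json") is now accepted as well.
update_json(f"{model_dir}/graph_net.json", "dtype_generalization_generated", True)

print(read_json(model_dir)["dtype_generalization_target_dtype"])  # "bfloat16"
```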

graph_net/test/dtype_gen_test.sh

Lines changed: 46 additions & 0 deletions
```diff
@@ -0,0 +1,46 @@
+#!/bin/bash
+
+GRAPH_NET_ROOT=$(python3 -c "import graph_net; import os; print(
+os.path.dirname(graph_net.__file__))")
+SAMPLES_ROOT="$GRAPH_NET_ROOT/../samples"
+OUTPUT_DIR="/tmp/dtype_gen_samples"
+mkdir -p "$OUTPUT_DIR"
+
+# Step 1: Initialize dtype generalization passes
+config_json_str_init=$(cat <<EOF
+{
+    "handler_path": "$GRAPH_NET_ROOT/torch/dtype_generalizer.py",
+    "handler_class_name": "InitDataTypeGeneralizationPasses",
+    "handler_config": {
+        "dtype_list": ["float16", "bfloat16"],
+        "model_path_prefix": "$SAMPLES_ROOT"
+    }
+}
+EOF
+)
+CONFIG_INIT=$(echo "$config_json_str_init" | base64 -w 0)
+
+python3 -m graph_net.model_path_handler --model-path "timm/resnet18" --handler-config=$CONFIG_INIT
+python3 -m graph_net.model_path_handler --model-path "transformers-auto-model/opus-mt-en-gmw" --handler-config=$CONFIG_INIT
+
+# Step 2: Apply passes to generate samples
+config_json_str_apply=$(cat <<EOF
+{
+    "handler_path": "$GRAPH_NET_ROOT/torch/dtype_generalizer.py",
+    "handler_class_name": "ApplyDataTypeGeneralizationPasses",
+    "handler_config": {
+        "output_dir": "$OUTPUT_DIR",
+        "model_path_prefix": "$SAMPLES_ROOT",
+        "model_runnable_predicator_filepath": "$GRAPH_NET_ROOT/torch/constraint_util.py",
+        "model_runnable_predicator_class_name": "RunModelPredicator",
+        "model_runnable_predicator_config": {
+            "use_dummy_inputs": true
+        }
+    }
+}
+EOF
+)
+CONFIG_APPLY=$(echo "$config_json_str_apply" | base64 -w 0)
+
+python3 -m graph_net.model_path_handler --model-path "timm/resnet18" --handler-config=$CONFIG_APPLY
+python3 -m graph_net.model_path_handler --model-path "transformers-auto-model/opus-mt-en-gmw" --handler-config=$CONFIG_APPLY
```
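The script serializes each handler config to single-line base64 so the JSON survives shell quoting on its way into `--handler-config`. A small Python sketch of that round-trip, mirroring the `base64 -w 0` step above (illustrative only; not claimed to be `model_path_handler`'s actual decoding code):

```python
import base64
import json

# Same shape as config_json_str_init above, with placeholder paths.
config = {
    "handler_path": "/path/to/graph_net/torch/dtype_generalizer.py",
    "handler_class_name": "InitDataTypeGeneralizationPasses",
    "handler_config": {"dtype_list": ["float16", "bfloat16"]},
}

# What CONFIG_INIT holds: one line of base64, safe to pass as a CLI argument.
encoded = base64.b64encode(json.dumps(config).encode()).decode()

# What the receiving side can recover from it.
decoded = json.loads(base64.b64decode(encoded))
assert decoded == config
```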
Lines changed: 3 additions & 0 deletions
```diff
@@ -0,0 +1,3 @@
+from graph_net.torch.dtype_gen_passes.pass_base import (
+    DtypeGeneralizationPass as DtypeGeneralizationPass,
+)
```
Lines changed: 144 additions & 0 deletions
```diff
@@ -0,0 +1,144 @@
+"""
+Concrete implementation of dtype generalization pass.
+
+This pass converts tensor dtypes in FX Graph by:
+1. Converting placeholder nodes (inputs) to target dtype
+2. Converting get_attr nodes (weights) to target dtype, except preserved weights
+3. Inserting .to(dtype) calls where needed
+"""
+
+import torch
+import torch.fx as fx
+from graph_net.torch.dtype_gen_passes.pass_base import DtypeGeneralizationPass
+
+
+class ConcretePass(DtypeGeneralizationPass):
+    """
+    FX Graph pass that converts dtypes of tensors.
+
+    This pass modifies the graph to:
+    - Convert input tensors to target dtype
+    - Convert weight tensors to target dtype (except preserved weights)
+    - Insert dtype conversion nodes where necessary
+    """
+
+    def get_pass_name(self) -> str:
+        return f"dtype_generalization_{self.target_dtype}"
+
+    def need_rewrite(self, gm: fx.GraphModule) -> bool:
+        """
+        Check if graph has float32 tensors that need conversion.
+        """
+        for node in gm.graph.nodes:
+            if self._node_need_rewrite(node):
+                return True
+        return False
+
+    def _node_need_rewrite(self, node: fx.Node) -> bool:
+        """
+        Check if a specific node needs dtype conversion.
+
+        Args:
+            node: FX Node to check
+
+        Returns:
+            True if node should be rewritten
+        """
+        # Check placeholder nodes (inputs)
+        if node.op == "placeholder":
+            return self._is_float32_tensor(node)
+
+        # Check get_attr nodes (weights)
+        if node.op == "get_attr":
+            if self._is_float32_tensor(node):
+                # Only rewrite if not in preserve list
+                attr_name = str(node.target)
+                return not self.should_preserve_weight(attr_name)
+
+        return False
+
+    def rewrite(self, gm: fx.GraphModule) -> fx.GraphModule:
+        """
+        Rewrite the graph to convert dtypes.
+
+        Strategy:
+        1. For each placeholder (input), insert .to(target_dtype) after it
+        2. For each get_attr (weight), insert .to(target_dtype) if not preserved
+        3. Update the graph and recompile
+        """
+        new_graph = fx.Graph()
+        val_map = {}
+
+        def create_placeholder(node: fx.Node) -> fx.Node:
+            """Create a placeholder node with dtype conversion if needed."""
+            new_node = new_graph.node_copy(node, lambda x: val_map.get(x, x))
+            if self._is_float32_tensor(node):
+                return new_graph.call_method("to", args=(new_node, self.torch_dtype))
+            return new_node
+
+        def create_get_attr(node: fx.Node) -> fx.Node:
+            """Create a get_attr node with dtype conversion if needed."""
+            new_node = new_graph.node_copy(node, lambda x: val_map.get(x, x))
+            attr_name = str(node.target)
+            if self._is_float32_tensor(node) and not self.should_preserve_weight(
+                attr_name
+            ):
+                return new_graph.call_method("to", args=(new_node, self.torch_dtype))
+            return new_node
+
+        for node in gm.graph.nodes:
+            if node.op == "placeholder":
+                val_map[node] = create_placeholder(node)
+            elif node.op == "get_attr":
+                val_map[node] = create_get_attr(node)
+            else:
+                new_node = new_graph.node_copy(node, lambda x: val_map.get(x, x))
+                val_map[node] = new_node
+
+        # Replace the graph
+        gm.graph = new_graph
+        gm.recompile()
+
+        return gm
+
+    def _is_float32_tensor(self, node: fx.Node) -> bool:
+        """
+        Check if a node represents a float32 tensor.
+
+        Args:
+            node: FX Node to check
+
+        Returns:
+            True if node is a float32 tensor
+        """
+        # Check tensor_meta if available (most reliable)
+        if "tensor_meta" in node.meta:
+            tensor_meta = node.meta["tensor_meta"]
+            if hasattr(tensor_meta, "dtype"):
+                return tensor_meta.dtype == torch.float32
+
+        # For placeholder and get_attr nodes without metadata,
+        # we need to be conservative and only return True if explicitly float
+        if node.op in ("placeholder", "get_attr"):
+            # Check type annotation if available
+            if node.type is not None:
+                type_str = str(node.type).lower()
+
+                # Explicitly check for integer types - these should NOT be converted
+                integer_types = ["long", "int", "short", "byte", "bool"]
+                if any(int_type in type_str for int_type in integer_types):
+                    return False
+
+                # Only return True if explicitly a floating point tensor
+                # Check for explicit float types: FloatTensor, float32, float16, etc.
+                float_indicators = ["float", "double", "half", "bfloat"]
+                if any(
+                    float_indicator in type_str for float_indicator in float_indicators
+                ):
+                    return True
+
+            # For generic "Tensor" without explicit dtype, be conservative
+            # Don't assume it's float32 - it might be integer
+            return False
+
+        return False
```
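Because `pass_base` (which supplies `should_preserve_weight`, `target_dtype`, and `torch_dtype`) is not part of this commit, here is a standalone sketch of the rewrite strategy in isolation, assuming only stock `torch.fx`: copy the graph node by node and wrap each float placeholder in a `call_method("to", ...)` node, the same effect `create_placeholder` achieves above:

```python
import torch
import torch.fx as fx


class Tiny(torch.nn.Module):
    def forward(self, x: torch.Tensor):
        return x * 2.0


gm = fx.symbolic_trace(Tiny())
new_graph = fx.Graph()
val_map = {}
for node in gm.graph.nodes:
    new_node = new_graph.node_copy(node, lambda n: val_map[n])
    if node.op == "placeholder":
        # Downstream users of the placeholder now see the cast value,
        # because val_map points them at the "to" node instead.
        new_node = new_graph.call_method("to", args=(new_node, torch.bfloat16))
    val_map[node] = new_node
gm.graph = new_graph
gm.recompile()

out = gm(torch.randn(4))  # the cast happens inside the rewritten graph
print(out.dtype)          # torch.bfloat16
```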
Lines changed: 35 additions & 0 deletions
```diff
@@ -0,0 +1,35 @@
+"""
+Dtype generalization pass for bfloat16.
+
+This pass converts float32 tensors to bfloat16.
+"""
+
+from graph_net.torch.dtype_gen_passes.dtype_generalization_pass import (
+    ConcretePass as BaseConcretePass,
+)
+
+# Weights that must remain float32 for numerical stability
+FLOAT32_PRESERVED_WEIGHTS = {
+    "running_mean",
+    "running_var",
+    "num_batches_tracked",
+    "bn_parameters_weight",
+    "bn_parameters_bias",
+    "ln_parameters_weight",
+    "ln_parameters_bias",
+}
+
+
+class ConcretePass(BaseConcretePass):
+    """
+    FX Graph pass that converts dtypes to bfloat16.
+    """
+
+    def __init__(self, *args, **kwargs):
+        # Override target_dtype to bfloat16
+        super().__init__(
+            target_dtype="bfloat16",
+            preserve_weights=FLOAT32_PRESERVED_WEIGHTS,
+            *args,
+            **kwargs,
+        )
```
Lines changed: 35 additions & 0 deletions
```diff
@@ -0,0 +1,35 @@
+"""
+Dtype generalization pass for float16.
+
+This pass converts float32 tensors to float16.
+"""
+
+from graph_net.torch.dtype_gen_passes.dtype_generalization_pass import (
+    ConcretePass as BaseConcretePass,
+)
+
+# Weights that must remain float32 for numerical stability
+FLOAT32_PRESERVED_WEIGHTS = {
+    "running_mean",
+    "running_var",
+    "num_batches_tracked",
+    "bn_parameters_weight",
+    "bn_parameters_bias",
+    "ln_parameters_weight",
+    "ln_parameters_bias",
+}
+
+
+class ConcretePass(BaseConcretePass):
+    """
+    FX Graph pass that converts dtypes to float16.
+    """
+
+    def __init__(self, *args, **kwargs):
+        # Override target_dtype to float16
+        super().__init__(
+            target_dtype="float16",
+            preserve_weights=FLOAT32_PRESERVED_WEIGHTS,
+            *args,
+            **kwargs,
+        )
```
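As background on why both passes pin normalization statistics via `FLOAT32_PRESERVED_WEIGHTS`: float16's largest finite value is about 65504, so a realistic BatchNorm `running_var` can overflow to `inf` when cast, while bfloat16 keeps float32's exponent range at reduced precision. A quick illustration (the variance magnitude is made up):

```python
import torch

running_var = torch.tensor([70000.0])  # float32, finite
print(running_var.to(torch.float16))   # tensor([inf], dtype=torch.float16): > 65504
print(running_var.to(torch.bfloat16))  # finite, ~7e4, but with only ~8 mantissa bits
```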
