Skip to content

Commit 528d46c

Browse files
committed
Merge branch 'develop' of github.com:PaddlePaddle/GraphNet into develop
2 parents 16264a7 + de2d8c4 commit 528d46c

12 files changed

+253
-42
lines changed

README.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,9 @@ python -m graph_net.plot_violin \
9696

9797
The scripts are designed to process a file structure as `/benchmark_path/category_name/`, and items on the x-axis are identified by the names of the sub-directories. After executing, several summary plots of the results grouped by category (model tasks, libraries...) will be exported to `$GRAPH_NET_BENCHMARK_PATH`.
9898

99+
### Hardware Regression Testing
100+
We also provide a two-step workflow that validates compiler correctness and performance against a "golden" reference, which is crucial for hardware-specific testing and regression tracking. Details can be found in this [guide](./docs/hardware_test.md).
101+
99102
### 🧱 Construction & Contribution Guide
100103
Want to understand how GraphNet is built or contribute new samples?
101104
Check out the [Construction Guide](./docs/README_contribute.md) for details on the extraction and validation workflow.

docs/hardware_test.md

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
## Hardware Regression Testing
2+
### Step 1: Generate Reference Data
3+
First, use `graph_net.paddle.test_reference_device` on a trusted setup (e.g., a specific hardware/compiler version) to generate baseline logs and output files.
4+
```bash
5+
python -m graph_net.paddle.test_reference_device \
6+
--model-path /path/to/all_models/ \
7+
    --reference-dir ./golden_reference \
8+
--compiler cinn \
9+
--device cuda
10+
# --reference-dir: (Required) Directory where the output .log (performance/config) and .pdout (output tensors) files will be saved.
11+
# --compiler: Specifies the compiler backend.
12+
```
13+
### Step 2: Run Regression Test
14+
After changing hardware, run the correctness test script. This script reads the reference data, re-runs the models using the exact same configuration, and compares the new results against the "golden" reference.
15+
```bash
16+
python -m graph_net.paddle.test_device_correctness \
17+
--reference-dir ./golden_reference \
18+
--device cuda
19+
```
20+
This script will report any failures (e.g., compilation errors, output mismatches) and print a performance comparison (speedup/slowdown) against the reference log, allowing you to quickly identify regressions.

graph_net/imp_util.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
import importlib.util as imp
2+
3+
4+
def load_module(path, name="unamed"):
5+
spec = imp.spec_from_file_location(name, path)
6+
module = imp.module_from_spec(spec)
7+
spec.loader.exec_module(module)
8+
return module
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
#!/bin/bash
2+
3+
GRAPH_NET_ROOT=$(python3 -c "import graph_net; import os; print(
4+
os.path.dirname(graph_net.__file__))")
5+
6+
# input model path
7+
MODEL_PATH_IN_SAMPLES=/timm/resnet18
8+
read -r -d '' extractor_config_json_str <<EOF
9+
{
10+
"custom_extractor_path": "$GRAPH_NET_ROOT/torch/naive_graph_decomposer.py",
11+
"custom_extractor_config": {
12+
"output_dir": "/tmp/naive_decompose_workspace",
13+
"split_positions": [8, 16, 32],
14+
"group_head_and_tail": true,
15+
"chain_style": true
16+
}
17+
}
18+
EOF
19+
EXTRACTOR_CONFIG=$(echo $extractor_config_json_str | base64 -w 0)
20+
21+
mkdir -p /tmp/naive_decompose_workspace
22+
python3 -m graph_net.torch.single_device_runner --model-path $GRAPH_NET_ROOT/../samples/$MODEL_PATH_IN_SAMPLES --enable-extract True --extract-name resnet18 --dump-graph-hash-key --extractor-config=$EXTRACTOR_CONFIG
Lines changed: 19 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,23 @@
11
#!/bin/bash
2-
# input model path
3-
MODEL_PATH_IN_SAMPLES=/timm/resnet18
4-
# output model path
5-
OUTPUT_DIR=/tmp/naive_decompose_workspace
62

7-
mkdir -p $OUTPUT_DIR
8-
# extract subgraph 0-8, 8-16
9-
export GRAPH_NET_NAIVE_DECOMPOSER_SPLIT_POS=0,8,16
10-
export GRAPH_NET_EXTRACT_WORKSPACE=$OUTPUT_DIR
113
GRAPH_NET_ROOT=$(python3 -c "import graph_net; import os; print(
124
os.path.dirname(graph_net.__file__))")
13-
python3 -m graph_net.torch.single_device_runner --model-path $GRAPH_NET_ROOT/../samples/$MODEL_PATH_IN_SAMPLES --enable-extract True --extract-name resnet18 --dump-graph-hash-key --custom-extractor-path=$GRAPH_NET_ROOT/torch/naive_graph_decomposer.py
5+
6+
# input model path
7+
MODEL_PATH_IN_SAMPLES=/timm/resnet18
8+
read -r -d '' extractor_config_json_str <<EOF
9+
{
10+
"custom_extractor_path": "$GRAPH_NET_ROOT/torch/naive_graph_decomposer.py",
11+
"custom_extractor_config": {
12+
"output_dir": "/tmp/naive_decompose_workspace",
13+
"split_positions": [8, 32],
14+
"group_head_and_tail": true,
15+
"filter_path":"$GRAPH_NET_ROOT/torch/naive_subgraph_filter.py",
16+
"filter_config": {}
17+
}
18+
}
19+
EOF
20+
EXTRACTOR_CONFIG=$(echo $extractor_config_json_str | base64 -w 0)
21+
22+
mkdir -p /tmp/naive_decompose_workspace
23+
python3 -m graph_net.torch.single_device_runner --model-path $GRAPH_NET_ROOT/../samples/$MODEL_PATH_IN_SAMPLES --enable-extract True --extract-name resnet18 --dump-graph-hash-key --extractor-config=$EXTRACTOR_CONFIG

graph_net/torch/backend/unstable_to_stable_backend.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,27 @@ def _impl_unstable_to_stable_fftn(self, gm):
126126

127127
return gm
128128

129+
def _impl_unstable_to_stable_special_logit(self, gm):
130+
"""
131+
Convert torch._C._special.special_logit to torch.special.logit
132+
"""
133+
issue_nodes = (
134+
node
135+
for node in gm.graph.nodes
136+
if node.op == "call_function"
137+
if hasattr(node.target, "__module__")
138+
if node.target.__module__ == "torch._C._special"
139+
if hasattr(node.target, "__name__")
140+
if node.target.__name__ == "special_logit"
141+
)
142+
for node in issue_nodes:
143+
node.target = torch.special.logit
144+
145+
# Recompile the graph
146+
gm.recompile()
147+
148+
return gm
149+
129150
def unstable_to_stable(self, gm):
130151
methods = (
131152
name

graph_net/torch/decompose_util.py

Lines changed: 50 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,13 @@ def convert_to_submodules_graph(
99
original_gm: torch.fx.GraphModule,
1010
split_positions: list[int],
1111
submodule_hook=None,
12-
submodule_name_prefix="extraced_submodule",
12+
submodule_name_prefix="extracted_submodule",
13+
chain_style=False,
1314
group_head_and_tail=True,
1415
):
16+
"""
17+
chain_style=True: decompose original_gm into g0 * g1 * g2 * g3
18+
"""
1519
original_gm = copy.deepcopy(original_gm)
1620
num_placeholders = len(
1721
[node for node in original_gm.graph.nodes if node.op == "placeholder"]
@@ -68,14 +72,22 @@ def get_end_node_idx(range_idx):
6872
return i + 1
6973
raise NotImplementedError("Dead code.")
7074

75+
def print_submodule_call(prompt, gm):
76+
submodule_call_stmts = [
77+
stmt for stmt in gm.code.split("\n") if "self.extracted_submodule" in stmt
78+
]
79+
print(f"{prompt} ", submodule_call_stmts)
80+
7181
for range_idx in range(len(range_idx2submodule_body_nodes)):
7282
(
7383
submodule_input_nodes,
7484
submodule_output_nodes,
85+
identity_nodes,
7586
) = _get_submodule_inputs_and_outputs(
7687
original_gm=original_gm,
7788
start_node_idx=get_start_node_idx(range_idx),
7889
end_node_idx=get_end_node_idx(range_idx),
90+
chain_style=chain_style,
7991
)
8092

8193
def get_input_nodes(range_idx):
@@ -130,15 +142,22 @@ def get_output_nodes(range_idx):
130142
prev_node = new_output_node
131143

132144
# Replace all use of outputs
145+
identity_node_set = set(identity_nodes)
133146
for original_output in get_output_nodes(range_idx):
134-
original_output.replace_all_uses_with(node_map[original_output])
147+
if original_output not in identity_node_set:
148+
original_output.replace_all_uses_with(node_map[original_output])
135149

136150
# Erase old nodes
137151
for node in reversed(get_body_nodes(range_idx)):
138152
original_gm.graph.erase_node(node)
153+
# print_submodule_call("(fx) after Erase old nodes", original_gm)
154+
155+
# print_submodule_call("(fx) before recompile", original_gm)
139156

140157
original_gm.recompile()
141158

159+
# print_submodule_call("(fx) after recompile", original_gm)
160+
142161
return original_gm
143162

144163

@@ -147,7 +166,7 @@ def fold_range_to_submodule(
147166
start_node_idx: int,
148167
end_node_idx: int,
149168
submodule_hook=None,
150-
submodule_name="extraced_submodule",
169+
submodule_name="extracted_submodule",
151170
group_head_and_tail=True,
152171
):
153172
return convert_to_submodules_graph(
@@ -170,6 +189,7 @@ def _get_submodule_inputs_and_outputs(
170189
original_gm: torch.fx.GraphModule,
171190
start_node_idx: int,
172191
end_node_idx: int,
192+
chain_style=False,
173193
):
174194
count_ctx = NodeProducedOrConsumedCountCtx(
175195
defaultdict(int),
@@ -179,7 +199,11 @@ def _get_submodule_inputs_and_outputs(
179199
node_list = list(original_gm.graph.nodes)
180200

181201
def get_related_node(node):
182-
yield from node.args
202+
for arg in node.args:
203+
if isinstance(arg, tuple):
204+
yield from arg
205+
else:
206+
yield arg
183207
yield node
184208

185209
for node in node_list[0:start_node_idx]:
@@ -200,13 +224,32 @@ def get_related_node(node):
200224
if count_ctx.node2before_input[node] > 0
201225
if count_ctx.node2body[node] > 0
202226
]
203-
204227
output_nodes = [
205228
node
206229
for node in node_list
207230
if not (count_ctx.node2before_input[node] > 0)
208231
if count_ctx.node2body[node] > 0
209232
if count_ctx.node2after_output[node] > 0
210233
]
211-
212-
return input_nodes, output_nodes
234+
if not chain_style:
235+
identity_nodes = []
236+
else:
237+
identity_nodes = [
238+
node
239+
for node in node_list
240+
if count_ctx.node2before_input[node] > 0
241+
if count_ctx.node2body[node] == 0
242+
if count_ctx.node2after_output[node] > 0
243+
][:1]
244+
input_nodes_set = set(input_nodes)
245+
input_nodes = [
246+
*input_nodes,
247+
*[node for node in identity_nodes if node not in input_nodes_set],
248+
]
249+
output_nodes_set = set(output_nodes)
250+
output_nodes = [
251+
*output_nodes,
252+
*[node for node in identity_nodes if node not in output_nodes_set],
253+
]
254+
255+
return input_nodes, output_nodes, identity_nodes

graph_net/torch/extractor.py

Lines changed: 35 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -15,14 +15,23 @@
1515

1616
class GraphExtractor:
1717
def __init__(
18-
self, name, dynamic, mut_graph_codes=None, placeholder_auto_rename=False
18+
self,
19+
name,
20+
dynamic,
21+
mut_graph_codes=None,
22+
placeholder_auto_rename=False,
23+
workspace_path=None,
1924
):
2025
self.subgraph_counter = 0
2126
self.name = name
2227
self.dynamic = dynamic
2328
self.mut_graph_codes = mut_graph_codes
2429
self.placeholder_auto_rename = placeholder_auto_rename
25-
self.workspace_path = os.environ.get("GRAPH_NET_EXTRACT_WORKSPACE")
30+
self.workspace_path = (
31+
workspace_path
32+
if workspace_path is not None
33+
else os.environ.get("GRAPH_NET_EXTRACT_WORKSPACE")
34+
)
2635
if not self.workspace_path:
2736
raise EnvironmentError(
2837
"Environment variable 'GRAPH_NET_EXTRACT_WORKSPACE' is not set."
@@ -130,7 +139,7 @@ def extract(
130139
dynamic=True,
131140
mut_graph_codes=None,
132141
placeholder_auto_rename=False,
133-
custom_extractor_path=None,
142+
extractor_config: dict = None,
134143
):
135144
"""
136145
Extract computation graphs from PyTorch nn.Module.
@@ -200,19 +209,24 @@ def forward(self, s0 : torch.SymInt, L_x_ : torch.Tensor):
200209
>>>
201210
"""
202211

203-
def get_graph_extractor_cls():
212+
extractor_config = make_extractor_config(extractor_config)
213+
214+
def get_graph_extractor_maker():
215+
custom_extractor_path = extractor_config["custom_extractor_path"]
216+
custom_extractor_config = extractor_config["custom_extractor_config"]
204217
if custom_extractor_path is None:
205218
return GraphExtractor
206219
import importlib.util as imp
207220

208221
spec = imp.spec_from_file_location("graph_extractor", custom_extractor_path)
209222
graph_extractor = imp.module_from_spec(spec)
210223
spec.loader.exec_module(graph_extractor)
211-
return graph_extractor.GraphExtractor
224+
cls = graph_extractor.GraphExtractor
225+
return lambda *args, **kwargs: cls(custom_extractor_config, *args, **kwargs)
212226

213227
def wrapper(model: torch.nn.Module):
214228
assert isinstance(model, torch.nn.Module), f"{type(model)=}"
215-
extractor = get_graph_extractor_cls()(
229+
extractor = get_graph_extractor_maker()(
216230
name, dynamic, mut_graph_codes, placeholder_auto_rename
217231
)
218232
# return torch.compile(backend=extractor, dynamic=dynamic)
@@ -236,3 +250,18 @@ def decorator_or_wrapper(obj):
236250
)
237251

238252
return decorator_or_wrapper
253+
254+
255+
def make_extractor_config(extractor_config):
256+
kwargs = extractor_config if extractor_config is not None else {}
257+
return make_extractor_config_impl(**kwargs)
258+
259+
260+
def make_extractor_config_impl(
261+
custom_extractor_path: str = None, custom_extractor_config: dict = None
262+
):
263+
config = custom_extractor_config if custom_extractor_config is not None else {}
264+
return {
265+
"custom_extractor_path": custom_extractor_path,
266+
"custom_extractor_config": config,
267+
}

graph_net/torch/fx_graph_serialize_util.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ def serialize_graph_module_to_str(gm: torch.fx.GraphModule) -> str:
2424
(r"torch\._C\._fft\.fft_irfft\(", "torch.fft.irfft("),
2525
(r"torch\._C\._fft\.fft_rfft\(", "torch.fft.rfft("),
2626
(r"torch\._C\._fft\.fft_fftn\(", "torch.fft.fftn("),
27+
(r"torch\._C\._special\.special_logit\(", "torch.special.logit("),
2728
# Add new rules to this list as needed
2829
]
2930
for pattern, repl in replacements:

0 commit comments

Comments
 (0)