Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
636284e
init 'symbolic_dimension_reifier' field in graph_net.json
lixinqi Dec 5, 2025
9cb5f37
Merge branch 'develop' of github.com:PaddlePaddle/GraphNet into develop
lixinqi Dec 8, 2025
47a56f4
Merge branch 'develop' of github.com:PaddlePaddle/GraphNet into develop
lixinqi Dec 8, 2025
21a9a66
Merge branch 'develop' of github.com:PaddlePaddle/GraphNet into develop
lixinqi Dec 10, 2025
74f4036
remove unused files
lixinqi Dec 10, 2025
d074d2c
Merge branch 'develop' of github.com:PaddlePaddle/GraphNet into develop
lixinqi Dec 10, 2025
57d92b5
Merge branch 'develop' of github.com:PaddlePaddle/GraphNet into develop
lixinqi Dec 10, 2025
d9a98d9
Merge branch 'develop' of github.com:PaddlePaddle/GraphNet into develop
lixinqi Dec 17, 2025
372ce6e
Merge branch 'develop' of github.com:PaddlePaddle/GraphNet into develop
lixinqi Dec 17, 2025
2d26424
Merge branch 'develop' of github.com:PaddlePaddle/GraphNet into develop
lixinqi Dec 18, 2025
e5088b3
Merge branch 'develop' of github.com:PaddlePaddle/GraphNet into develop
lixinqi Dec 18, 2025
b2e46e8
Merge branch 'develop' of github.com:PaddlePaddle/GraphNet into develop
lixinqi Dec 21, 2025
103873e
Merge branch 'develop' of github.com:PaddlePaddle/GraphNet into develop
lixinqi Dec 22, 2025
54244cd
Merge branch 'develop' of github.com:PaddlePaddle/GraphNet into develop
lixinqi Dec 23, 2025
b5fb059
Merge branch 'develop' of github.com:PaddlePaddle/GraphNet into develop
lixinqi Dec 23, 2025
9fccdac
Merge branch 'develop' of github.com:PaddlePaddle/GraphNet into develop
lixinqi Dec 24, 2025
547fd94
Merge branch 'develop' of github.com:PaddlePaddle/GraphNet into develop
lixinqi Dec 25, 2025
2001f81
backup code
lixinqi Dec 31, 2025
9223a93
Merge branch 'develop' of github.com:PaddlePaddle/GraphNet into shape…
lixinqi Dec 31, 2025
c21c6ee
add sample_pass/group_ranges_from_subgraph_sources.py
lixinqi Jan 1, 2026
204b53e
add sample pass torch/sample_pass/subgraph_input_producer_indexes_gen…
lixinqi Jan 1, 2026
9678c73
add sample pass torch/sample_pass/shape_propagator.py
lixinqi Jan 1, 2026
21ccfd0
save subgraph relative model paths
lixinqi Jan 1, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 11 additions & 2 deletions graph_net/apply_sample_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,23 +39,32 @@ def _get_handler(args):
def main(args):
handler = _get_handler(args)
if args.model_path is not None:
assert not hasattr(handler, "BEGIN")
assert not hasattr(handler, "END")
handle_model_path(handler, args.model_path)
elif args.use_subprocess:
assert not hasattr(handler, "BEGIN")
assert not hasattr(handler, "END")
handle_model_path_list_in_subprocess(args)
else:
handle_model_path_list_in_current_process(handler, args)


def handle_model_path_list_in_current_process(handler, args):
for model_path in _get_model_path_list(args):
rel_model_paths = list(_get_model_path_list(args))
if hasattr(handler, "BEGIN"):
handler.BEGIN(rel_model_paths)
for rel_model_path in rel_model_paths:
try:
handle_model_path(handler, model_path)
handle_model_path(handler, rel_model_path)
except KeyboardInterrupt:
print("KeyboardInterrupt")
return
except Exception:
print("------------[apply_sample_pass failed]------------", flush=True)
traceback.print_exc()
if hasattr(handler, "END"):
handler.END(rel_model_paths)


def handle_model_path_list_in_subprocess(args):
Expand Down
16 changes: 16 additions & 0 deletions graph_net/bash_templates/apply_sample_pass_sh.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
#!/bin/bash

GRAPH_NET_ROOT=$(python3 -c "import graph_net; import os; print(os.path.dirname(os.path.dirname(graph_net.__file__)))")

python3 -m graph_net.apply_sample_pass \
--model-path-list "customize_your_model_path_list" \
--sample-pass-file-path "$GRAPH_NET_ROOT/graph_net/customize_your_sample_pass.py" \
--sample-pass-class-name customize_your_class_name \
--sample-pass-config $(base64 -w 0 <<EOF
{
"resume": true,
"model_path_prefix": "/customize_your_model_path_prefix",
"output_dir": "/customize_your_output_file"
}
EOF
)
102 changes: 102 additions & 0 deletions graph_net/sample_pass/group_ranges_from_subgraph_sources.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
from graph_net.sample_pass.sample_pass import SamplePass
from pathlib import Path
import json


class GroupRangesFromSubgraphSources(SamplePass):
def __init__(self, config=None):
super().__init__(config)
self.original_graph_rel_model_path2ranges: dict[str, list[(int, int)]] = {}
self.original_graph_rel_model_path2subgraph_rel_model_paths: dict[
str, list[str]
] = {}

def declare_config(
self,
subgraph_model_path_prefix: str,
output_dir: str,
subgraph_sources_json_file_name: str = "subgraph_sources.json",
output_json_file_name: str = "grouped_ranges_from_subgraph_sources.json",
output_json_key: str = "grouped_ranges_from_subgraph_sources",
output_json_subgraph_rel_model_path_key: str = "subgraph_rel_model_paths",
):
pass

def __call__(self, subgraph_rel_model_path: str):
model_path = (
Path(self.config["subgraph_model_path_prefix"])
/ subgraph_rel_model_path
/ self.config["subgraph_sources_json_file_name"]
)
subgraph_sources = json.load(open(model_path))
for original_graph_rel_model_path, subgraph_ranges in subgraph_sources.items():
self._collect_original_graph_rel_model_path2ranges(
original_graph_rel_model_path, subgraph_ranges
)
self._collect_original_graph_rel_model_path2subgraph_rel_model_path(
original_graph_rel_model_path,
[subgraph_rel_model_path] * len(subgraph_ranges),
)

def _collect_original_graph_rel_model_path2subgraph_rel_model_path(
self,
original_graph_rel_model_path: str,
subgraph_rel_model_paths: list[str],
):
old = self.original_graph_rel_model_path2subgraph_rel_model_paths.get(
original_graph_rel_model_path, []
)
self.original_graph_rel_model_path2subgraph_rel_model_paths[
original_graph_rel_model_path
] = [
*old,
*subgraph_rel_model_paths,
]

def _collect_original_graph_rel_model_path2ranges(
self, original_graph_rel_model_path, subgraph_ranges
):
old_ranges = self.original_graph_rel_model_path2ranges.get(
original_graph_rel_model_path, []
)
self.original_graph_rel_model_path2ranges[original_graph_rel_model_path] = [
*old_ranges,
*subgraph_ranges,
]

def END(self, rel_model_paths: list[str]):
for (
original_graph_rel_model_path,
subgraph_ranges,
) in self.original_graph_rel_model_path2ranges.items():
subgraph_rel_model_paths = (
self.original_graph_rel_model_path2subgraph_rel_model_paths[
original_graph_rel_model_path
]
)
self._save_json(
original_graph_rel_model_path, subgraph_ranges, subgraph_rel_model_paths
)

def _save_json(
self, original_graph_rel_model_path, subgraph_ranges, subgraph_rel_model_paths
):
model_dir = Path(self.config["output_dir"]) / original_graph_rel_model_path
model_dir.mkdir(parents=True, exist_ok=True)
ranges_json = self._get_ranges_json(subgraph_ranges)
paths_json = self._get_paths_json(subgraph_rel_model_paths)
json_obj = {**ranges_json, **paths_json}
json_str = json.dumps(json_obj, indent=4)
(model_dir / self.config["output_json_file_name"]).write_text(json_str)

def _get_paths_json(self, subgraph_rel_model_paths: list[str]):
json_obj = {
self.config[
"output_json_subgraph_rel_model_path_key"
]: subgraph_rel_model_paths
}
return json_obj

def _get_ranges_json(self, subgraph_ranges: list[(int, int)]):
json_obj = {self.config["output_json_key"]: subgraph_ranges}
return json_obj
6 changes: 3 additions & 3 deletions graph_net/sample_pass/sample_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def _make_config_by_config_declare(self, config):
class_name = type(self).__name__
assert name in mut_config, f"{name=} {class_name=}"

def get_extra_config_fields():
def get_undefined_config_fields():
return set(name for name, _ in mut_config.items()) - set(
name for name, _ in sig.parameters.items()
)
Expand All @@ -71,10 +71,10 @@ def get_extra_config_fields():
for _, param in sig.parameters.items()
)
if no_varadic_keyword:
no_extra_config_fields = all(
no_undefined_config_fields = all(
name in sig.parameters for name, _ in mut_config.items()
)
assert no_extra_config_fields, f"{get_extra_config_fields()=}"
assert no_undefined_config_fields, f"{get_undefined_config_fields()=}"
return mut_config

def _complete_default(self, name, param, mut_config):
Expand Down
15 changes: 15 additions & 0 deletions graph_net/test/group_ranges_from_subgraph_sources_test.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
#!/bin/bash

GRAPH_NET_ROOT=$(python3 -c "import graph_net; import os; print(os.path.dirname(os.path.dirname(graph_net.__file__)))")

python3 -m graph_net.apply_sample_pass \
--model-path-list "$GRAPH_NET_ROOT/graph_net/test/workspace_group_ranges_from_subgraph_sources/sample_list.txt" \
--sample-pass-file-path "$GRAPH_NET_ROOT/graph_net/sample_pass/group_ranges_from_subgraph_sources.py" \
--sample-pass-class-name GroupRangesFromSubgraphSources \
--sample-pass-config $(base64 -w 0 <<EOF
{
"subgraph_model_path_prefix": "$GRAPH_NET_ROOT/graph_net/test/workspace_group_ranges_from_subgraph_sources",
"output_dir": "/tmp/workspace_group_ranges_from_subgraph_sources"
}
EOF
)
16 changes: 16 additions & 0 deletions graph_net/test/shape_propagator_test.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
#!/bin/bash

GRAPH_NET_ROOT=$(python3 -c "import graph_net; import os; print(os.path.dirname(os.path.dirname(graph_net.__file__)))")

python3 -m graph_net.apply_sample_pass \
--model-path-list "$GRAPH_NET_ROOT/graph_net/test/workspace_subgraph_input_producer_indexes_generator/sample_list.txt" \
--sample-pass-file-path "$GRAPH_NET_ROOT/graph_net/torch/sample_pass/shape_propagator.py" \
--sample-pass-class-name ShapePropagator \
--sample-pass-config $(base64 -w 0 <<EOF
{
"resume": false,
"model_path_prefix": "$GRAPH_NET_ROOT",
"output_dir": "/tmp/workspace_shape_propagator"
}
EOF
)
19 changes: 19 additions & 0 deletions graph_net/test/subgraph_input_producer_indexes_generator_test.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
#!/bin/bash

GRAPH_NET_ROOT=$(python3 -c "import graph_net; import os; print(os.path.dirname(os.path.dirname(graph_net.__file__)))")

python3 -m graph_net.apply_sample_pass \
--model-path-list "$GRAPH_NET_ROOT/graph_net/test/workspace_subgraph_input_producer_indexes_generator/sample_list.txt" \
--sample-pass-file-path "$GRAPH_NET_ROOT/graph_net/torch/sample_pass/subgraph_input_producer_indexes_generator.py" \
--sample-pass-class-name SubgraphInputProducerIndexesGenerator \
--sample-pass-config $(base64 -w 0 <<EOF
{
"resume": false,
"model_path_prefix": "$GRAPH_NET_ROOT",
"subgraph_ranges_json_root": "$GRAPH_NET_ROOT/graph_net/test/workspace_subgraph_input_producer_indexes_generator",
"subgraph_ranges_json_file_name": "grouped_ranges_from_subgraph_sources.json",
"subgraph_ranges_json_key": "grouped_ranges_from_subgraph_sources",
"output_dir": "/tmp/workspace_subgraph_input_producer_indexes_generator"
}
EOF
)
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
samples/timm/mobilenetv4_conv_aa_large.e230_r384_in12k/_decomposed/mobilenetv4_conv_aa_large.e230_r384_in12k_start33_end38_6
samples/timm/mobilenetv4_conv_aa_large.e230_r384_in12k/_decomposed/mobilenetv4_conv_aa_large.e230_r384_in12k_start55_end60_10
samples/timm/mobilenetv4_conv_aa_large.e230_r384_in12k/_decomposed/mobilenetv4_conv_aa_large.e230_r384_in12k_start60_end64_11
samples/timm/convnextv2_base.fcmae_ft_in1k/_decomposed/convnextv2_base.fcmae_ft_in1k_start287_end297_32
samples/timm/convnextv2_base.fcmae_ft_in1k/_decomposed/convnextv2_base.fcmae_ft_in1k_start401_end411_44
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
{
"samples/timm/convnextv2_base.fcmae_ft_in1k": [
[
287,
297
]
]
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
{
"samples/timm/convnextv2_base.fcmae_ft_in1k": [
[
401,
411
]
]
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
{
"samples/timm/mobilenetv4_conv_aa_large.e230_r384_in12k": [
[
33,
38
]
]
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
{
"samples/timm/mobilenetv4_conv_aa_large.e230_r384_in12k": [
[
55,
60
]
]
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
{
"samples/timm/mobilenetv4_conv_aa_large.e230_r384_in12k": [
[
60,
64
]
]
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
samples/timm/mobilenetv4_conv_aa_large.e230_r384_in12k
samples/timm/convnextv2_base.fcmae_ft_in1k
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
{
"grouped_ranges_from_subgraph_sources": [
[
287,
297
],
[
401,
411
]
],
"subgraph_rel_model_paths": [
"samples/timm/convnextv2_base.fcmae_ft_in1k/_decomposed/convnextv2_base.fcmae_ft_in1k_start287_end297_32",
"samples/timm/convnextv2_base.fcmae_ft_in1k/_decomposed/convnextv2_base.fcmae_ft_in1k_start401_end411_44"
]
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
{
"grouped_ranges_from_subgraph_sources": [
[
33,
38
],
[
55,
60
],
[
60,
64
]
],
"subgraph_rel_model_paths": [
"samples/timm/mobilenetv4_conv_aa_large.e230_r384_in12k/_decomposed/mobilenetv4_conv_aa_large.e230_r384_in12k_start33_end38_6",
"samples/timm/mobilenetv4_conv_aa_large.e230_r384_in12k/_decomposed/mobilenetv4_conv_aa_large.e230_r384_in12k_start55_end60_10",
"samples/timm/mobilenetv4_conv_aa_large.e230_r384_in12k/_decomposed/mobilenetv4_conv_aa_large.e230_r384_in12k_start60_end64_11"
]
}
Loading