Skip to content

Commit e5088b3

Browse files
committed
Merge branch 'develop' of github.com:PaddlePaddle/GraphNet into develop
2 parents 2d26424 + 9a55a74 commit e5088b3

File tree

7 files changed

+299
-268
lines changed

7 files changed

+299
-268
lines changed

graph_net/sample_pass/sample_pass.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,8 @@ def _make_config_by_config_declare(self, config):
5555
sig = inspect.signature(self.declare_config)
5656
mut_config = copy.deepcopy(config)
5757
for name, param in sig.parameters.items():
58-
self._complete_default(name, param, mut_config)
58+
if name not in mut_config:
59+
self._complete_default(name, param, mut_config)
5960
class_name = type(self).__name__
6061
assert name in mut_config, f"{name=} {class_name=}"
6162

graph_net/test/test_agent_unittest_generator.sh renamed to graph_net/test/agent_unittest_generator_test.sh

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,22 @@
11
#!/usr/bin/env bash
2-
# set -euo pipefail
3-
4-
# Smoke tests for AgentUnittestGenerator using model_path_handler + sample pass.
52

63
ROOT_DIR="$(cd "$(dirname "$0")/../.." && pwd)"
74
GRAPH_NET_ROOT=$(python -c "import graph_net, os; print(os.path.dirname(os.path.dirname(graph_net.__file__)))")
8-
HANDLER_PATH="$GRAPH_NET_ROOT/graph_net/torch/sample_passes/agent_unittest_generator.py"
5+
96
MODEL_PATH_PREFIX="$ROOT_DIR"
10-
OUTPUT_DIR="$ROOT_DIR"
7+
OUTPUT_DIR="/tmp/agent_unittests"
118

129
HANDLER_CONFIG=$(base64 -w 0 <<EOF
1310
{
14-
"handler_path": "$HANDLER_PATH",
11+
"handler_path": "$GRAPH_NET_ROOT/graph_net/torch/sample_passes/agent_unittest_generator.py",
1512
"handler_class_name": "AgentUnittestGeneratorPass",
1613
"handler_config": {
1714
"model_path_prefix": "$MODEL_PATH_PREFIX",
1815
"output_dir": "$OUTPUT_DIR",
19-
"force_device": "auto",
20-
"use_dummy_inputs": false
16+
"device": "auto",
17+
"generate_main": true,
18+
"data_input_predicator_filepath": "$GRAPH_NET_ROOT/graph_net/torch/constraint_util.py",
19+
"data_input_predicator_class_name": "NaiveDataInputPredicator"
2120
}
2221
}
2322
EOF

graph_net/torch/constraint_util.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,14 @@ def __call__(self, model_path, input_var_name: str) -> bool:
1616
)
1717

1818

19+
class RenamedDataInputPredicator:
20+
def __init__(self, config):
21+
self.config = config
22+
23+
def __call__(self, model_path, input_var_name: str) -> bool:
24+
return input_var_name.startswith("in_")
25+
26+
1927
class ModelRunnablePredicator:
2028
def __init__(self, config):
2129
if config is None:

graph_net/torch/fx_graph_module_util.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,10 @@
55
from dataclasses import asdict
66

77

8-
def get_torch_module_and_inputs(model_path):
8+
def get_torch_module_and_inputs(model_path, use_dummy_inputs=True):
99
module = _get_torch_module(model_path)
1010
tensor_metas = _get_tensor_metas(model_path)
11-
inputs = _create_inputs_by_metas(module, tensor_metas)
11+
inputs = _create_inputs_by_metas(module, tensor_metas, use_dummy_inputs)
1212
return module, inputs
1313

1414

@@ -27,11 +27,11 @@ def _get_tensor_metas(model_path):
2727
]
2828

2929

30-
def _create_inputs_by_metas(module, tensor_metas):
30+
def _create_inputs_by_metas(module, tensor_metas, use_dummy_inputs):
3131
tensor_meta_attrs_list = [asdict(tensor_meta) for tensor_meta in tensor_metas]
32-
from graph_net.torch.utils import get_dummy_named_tensors
32+
from graph_net.torch.utils import get_named_tensors
3333

34-
named_tensors = get_dummy_named_tensors(tensor_meta_attrs_list)
34+
named_tensors = get_named_tensors(tensor_meta_attrs_list, use_dummy_inputs)
3535
name2tensor = {k: v for k, v in named_tensors}
3636
return tuple(
3737
name2tensor[name] for name in inspect.signature(module.forward).parameters

graph_net/torch/graph_decomposer.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -130,7 +130,7 @@ def __call__(self, rel_model_path):
130130
for k, v in self.config.items()
131131
if k in {"split_positions", "group_head_and_tail", "chain_style"}
132132
}
133-
module, inputs = get_torch_module_and_inputs(model_path)
133+
module, inputs = get_torch_module_and_inputs(model_path, use_dummy_inputs=False)
134134
gm = parse_immutable_model_path_into_sole_graph_module(model_path)
135135
try:
136136
# logger.warning("convert_to_submodules_graph-call-begin")
@@ -227,7 +227,7 @@ def __call__(self, rel_model_path):
227227
"group_head_and_tail": self.config.get("group_head_and_tail", False),
228228
"chain_style": self.config.get("chain_style", False),
229229
}
230-
module, inputs = get_torch_module_and_inputs(model_path)
230+
module, inputs = get_torch_module_and_inputs(model_path, use_dummy_inputs=False)
231231
gm = parse_sole_graph_module(module, inputs)
232232
rewrited_gm: torch.fx.GraphModule = convert_to_submodules_graph(
233233
gm,

0 commit comments

Comments
 (0)