Skip to content

Commit 5885e0c

Browse files
authored
[Feature Enhancement] Opimize the unittest generator in llm format for torch. (#452)
* Support to generate unittest for llm. * Add RenamedDataInputPredicator to support generate unittest for renamed subgraphs.
1 parent c24ad82 commit 5885e0c

File tree

4 files changed

+287
-259
lines changed

4 files changed

+287
-259
lines changed

graph_net/sample_pass/sample_pass.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,8 @@ def _make_config_by_config_declare(self, config):
5555
sig = inspect.signature(self.declare_config)
5656
mut_config = copy.deepcopy(config)
5757
for name, param in sig.parameters.items():
58-
self._complete_default(name, param, mut_config)
58+
if name not in mut_config:
59+
self._complete_default(name, param, mut_config)
5960
class_name = type(self).__name__
6061
assert name in mut_config, f"{name=} {class_name=}"
6162

graph_net/test/test_agent_unittest_generator.sh renamed to graph_net/test/agent_unittest_generator_test.sh

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,22 @@
11
#!/usr/bin/env bash
2-
# set -euo pipefail
3-
4-
# Smoke tests for AgentUnittestGenerator using model_path_handler + sample pass.
52

63
ROOT_DIR="$(cd "$(dirname "$0")/../.." && pwd)"
74
GRAPH_NET_ROOT=$(python -c "import graph_net, os; print(os.path.dirname(os.path.dirname(graph_net.__file__)))")
8-
HANDLER_PATH="$GRAPH_NET_ROOT/graph_net/torch/sample_passes/agent_unittest_generator.py"
5+
96
MODEL_PATH_PREFIX="$ROOT_DIR"
10-
OUTPUT_DIR="$ROOT_DIR"
7+
OUTPUT_DIR="/tmp/agent_unittests"
118

129
HANDLER_CONFIG=$(base64 -w 0 <<EOF
1310
{
14-
"handler_path": "$HANDLER_PATH",
11+
"handler_path": "$GRAPH_NET_ROOT/graph_net/torch/sample_passes/agent_unittest_generator.py",
1512
"handler_class_name": "AgentUnittestGeneratorPass",
1613
"handler_config": {
1714
"model_path_prefix": "$MODEL_PATH_PREFIX",
1815
"output_dir": "$OUTPUT_DIR",
19-
"force_device": "auto",
20-
"use_dummy_inputs": false
16+
"device": "auto",
17+
"generate_main": true,
18+
"data_input_predicator_filepath": "$GRAPH_NET_ROOT/graph_net/torch/constraint_util.py",
19+
"data_input_predicator_class_name": "NaiveDataInputPredicator"
2120
}
2221
}
2322
EOF

graph_net/torch/constraint_util.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,14 @@ def __call__(self, model_path, input_var_name: str) -> bool:
1616
)
1717

1818

19+
class RenamedDataInputPredicator:
20+
def __init__(self, config):
21+
self.config = config
22+
23+
def __call__(self, model_path, input_var_name: str) -> bool:
24+
return input_var_name.startswith("in_")
25+
26+
1927
class ModelRunnablePredicator:
2028
def __init__(self, config):
2129
if config is None:

0 commit comments

Comments
 (0)