Skip to content

Commit 96f2658

Browse files
committed
Support to generate unittest for llm.
1 parent 3af85cf commit 96f2658

File tree

3 files changed

+245
-259
lines changed

3 files changed

+245
-259
lines changed

graph_net/sample_pass/sample_pass.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,8 @@ def _make_config_by_config_declare(self, config):
5555
sig = inspect.signature(self.declare_config)
5656
mut_config = copy.deepcopy(config)
5757
for name, param in sig.parameters.items():
58-
self._complete_default(name, param, mut_config)
58+
if name not in mut_config:
59+
self._complete_default(name, param, mut_config)
5960
class_name = type(self).__name__
6061
assert name in mut_config, f"{name=} {class_name=}"
6162

graph_net/test/test_agent_unittest_generator.sh renamed to graph_net/test/agent_unittest_generator_test.sh

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,22 @@
11
#!/usr/bin/env bash
2-
# set -euo pipefail
3-
4-
# Smoke tests for AgentUnittestGenerator using model_path_handler + sample pass.
52

63
ROOT_DIR="$(cd "$(dirname "$0")/../.." && pwd)"
74
GRAPH_NET_ROOT=$(python -c "import graph_net, os; print(os.path.dirname(os.path.dirname(graph_net.__file__)))")
8-
HANDLER_PATH="$GRAPH_NET_ROOT/graph_net/torch/sample_passes/agent_unittest_generator.py"
5+
96
MODEL_PATH_PREFIX="$ROOT_DIR"
10-
OUTPUT_DIR="$ROOT_DIR"
7+
OUTPUT_DIR="/tmp/agent_unittests"
118

129
HANDLER_CONFIG=$(base64 -w 0 <<EOF
1310
{
14-
"handler_path": "$HANDLER_PATH",
11+
"handler_path": "$GRAPH_NET_ROOT/graph_net/torch/sample_passes/agent_unittest_generator.py",
1512
"handler_class_name": "AgentUnittestGeneratorPass",
1613
"handler_config": {
1714
"model_path_prefix": "$MODEL_PATH_PREFIX",
1815
"output_dir": "$OUTPUT_DIR",
19-
"force_device": "auto",
20-
"use_dummy_inputs": false
16+
"device": "auto",
17+
"generate_main": true,
18+
"data_input_predicator_filepath": "$GRAPH_NET_ROOT/graph_net/torch/constraint_util.py",
19+
"data_input_predicator_class_name": "NaiveDataInputPredicator"
2120
}
2221
}
2322
EOF

0 commit comments

Comments
 (0)