Skip to content

Commit 2353301

Browse files
committed
Enable resume.
1 parent 2ed0c1a commit 2353301

File tree

2 files changed

+10
-2
lines changed

2 files changed

+10
-2
lines changed

graph_net/test/agent_unittest_generator_test.sh

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
#!/usr/bin/env bash
22

3-
ROOT_DIR="$(cd "$(dirname "$0")/../.." && pwd)"
3+
4+
45
GRAPH_NET_ROOT=$(python -c "import graph_net, os; print(os.path.dirname(os.path.dirname(graph_net.__file__)))")
56
OUTPUT_DIR="/tmp/agent_unittests"
67

@@ -14,6 +15,7 @@ HANDLER_CONFIG=$(base64 -w 0 <<EOF
1415
"device": "auto",
1516
"generate_main": true,
1617
"try_run": true,
18+
"resume": false,
1719
"data_input_predicator_filepath": "$GRAPH_NET_ROOT/graph_net/torch/constraint_util.py",
1820
"data_input_predicator_class_name": "NaiveDataInputPredicator"
1921
}

graph_net/torch/sample_passes/agent_unittest_generator.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
from graph_net import imp_util
1616
from graph_net.sample_pass.sample_pass import SamplePass
17+
from graph_net.sample_pass.resumable_sample_pass_mixin import ResumableSamplePassMixin
1718
from graph_net.tensor_meta import TensorMeta
1819

1920

@@ -298,7 +299,7 @@ def _render_template(self, graph_module_desc):
298299
return template.render(graph_module_desc=graph_module_desc)
299300

300301

301-
class AgentUnittestGeneratorPass(SamplePass):
302+
class AgentUnittestGeneratorPass(SamplePass, ResumableSamplePassMixin):
302303
"""SamplePass wrapper to generate Torch unittests via model_path_handler."""
303304

304305
def __init__(self, config=None):
@@ -313,10 +314,15 @@ def declare_config(
313314
try_run: bool = False,
314315
data_input_predicator_filepath: str = None,
315316
data_input_predicator_class_name: str = None,
317+
resume: bool = False,
318+
limits_handled_models: int = None,
316319
):
317320
pass
318321

319322
def __call__(self, rel_model_path: str):
323+
self.resumable_handle_sample(rel_model_path)
324+
325+
def resume(self, rel_model_path: str):
320326
model_path_prefix = Path(self.config["model_path_prefix"])
321327
output_dir = Path(self.config["output_dir"])
322328
generator = AgentUnittestGenerator(

0 commit comments

Comments
 (0)