Skip to content

Commit abfd75a

Browse files
committed
Support try-run when generate_main is not enabled.
1 parent ddc026b commit abfd75a

File tree

1 file changed

+21
-15
lines changed

1 file changed

+21
-15
lines changed

graph_net/sample_pass/agent_unittest_generator.py

Lines changed: 21 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -224,7 +224,7 @@ def __init__(
224224
self.output_dir = Path(output_dir)
225225
self.device = self._choose_device(device)
226226
self.generate_main = generate_main
227-
self.try_run = try_run and generate_main
227+
self.try_run = try_run
228228
self.data_input_predicator = self._make_data_input_predicator(
229229
data_input_predicator_filepath, data_input_predicator_class_name
230230
)
@@ -244,20 +244,26 @@ def generate(self):
244244
input_tensor_metas,
245245
weight_tensor_metas,
246246
) = self._get_input_and_weight_tensor_metas(input_arg_names, weight_arg_names)
247-
graph_module_desc = GraphModuleDescriptor(
248-
device=self.device,
249-
generate_main=self.generate_main,
250-
model_name=model_name,
251-
input_arg_names=input_arg_names,
252-
input_tensor_metas=input_tensor_metas,
253-
weight_arg_names=weight_arg_names,
254-
weight_tensor_metas=weight_tensor_metas,
255-
forward_body=self._get_forward_body(
256-
graph_module, input_arg_names, weight_arg_names
257-
),
258-
)
259-
unittest = self._render_template(graph_module_desc)
260-
if self._try_to_run_unittest(unittest):
247+
248+
def _generate_unittest(generate_main):
249+
graph_module_desc = GraphModuleDescriptor(
250+
device=self.device,
251+
generate_main=generate_main,
252+
model_name=model_name,
253+
input_arg_names=input_arg_names,
254+
input_tensor_metas=input_tensor_metas,
255+
weight_arg_names=weight_arg_names,
256+
weight_tensor_metas=weight_tensor_metas,
257+
forward_body=self._get_forward_body(
258+
graph_module, input_arg_names, weight_arg_names
259+
),
260+
)
261+
return self._render_template(graph_module_desc)
262+
263+
# Generate unittest with main for try-run.
264+
unittest_for_try_run = _generate_unittest(generate_main=self.try_run)
265+
if self._try_to_run_unittest(unittest_for_try_run):
266+
unittest = _generate_unittest(generate_main=self.generate_main)
261267
self._write_to_file(unittest, self.output_dir)
262268

263269
def _choose_device(self, device) -> str:

0 commit comments

Comments
 (0)