@@ -224,7 +224,7 @@ def __init__(
224224 self .output_dir = Path (output_dir )
225225 self .device = self ._choose_device (device )
226226 self .generate_main = generate_main
227- self .try_run = try_run and generate_main
227+ self .try_run = try_run
228228 self .data_input_predicator = self ._make_data_input_predicator (
229229 data_input_predicator_filepath , data_input_predicator_class_name
230230 )
@@ -244,20 +244,26 @@ def generate(self):
244244 input_tensor_metas ,
245245 weight_tensor_metas ,
246246 ) = self ._get_input_and_weight_tensor_metas (input_arg_names , weight_arg_names )
247- graph_module_desc = GraphModuleDescriptor (
248- device = self .device ,
249- generate_main = self .generate_main ,
250- model_name = model_name ,
251- input_arg_names = input_arg_names ,
252- input_tensor_metas = input_tensor_metas ,
253- weight_arg_names = weight_arg_names ,
254- weight_tensor_metas = weight_tensor_metas ,
255- forward_body = self ._get_forward_body (
256- graph_module , input_arg_names , weight_arg_names
257- ),
258- )
259- unittest = self ._render_template (graph_module_desc )
260- if self ._try_to_run_unittest (unittest ):
247+
248+ def _generate_unittest (generate_main ):
249+ graph_module_desc = GraphModuleDescriptor (
250+ device = self .device ,
251+ generate_main = generate_main ,
252+ model_name = model_name ,
253+ input_arg_names = input_arg_names ,
254+ input_tensor_metas = input_tensor_metas ,
255+ weight_arg_names = weight_arg_names ,
256+ weight_tensor_metas = weight_tensor_metas ,
257+ forward_body = self ._get_forward_body (
258+ graph_module , input_arg_names , weight_arg_names
259+ ),
260+ )
261+ return self ._render_template (graph_module_desc )
262+
263+ # Generate unittest with main for try-run.
264+ unittest_for_try_run = _generate_unittest (generate_main = self .try_run )
265+ if self ._try_to_run_unittest (unittest_for_try_run ):
266+ unittest = _generate_unittest (generate_main = self .generate_main )
261267 self ._write_to_file (unittest , self .output_dir )
262268
263269 def _choose_device (self , device ) -> str :
0 commit comments