File tree Expand file tree Collapse file tree 1 file changed +11
-3
lines changed
Expand file tree Collapse file tree 1 file changed +11
-3
lines changed Original file line number Diff line number Diff line change @@ -114,11 +114,19 @@ def test_single_model(args):
114114 compiled_duration_box = DurationBox (- 1 )
115115 with naive_timer (compiled_duration_box , synchronizer_func ):
116116 compiled_out = compiled_model (** input_dict )
117- expected_out = expected_out .numpy ()
118- compiled_out = compiled_out .numpy ()
117+
118+ expected_out_list = (
119+ expected_out if isinstance (expected_out , (list , tuple )) else [expected_out ]
120+ )
121+ compiled_out_list = (
122+ compiled_out if isinstance (compiled_out , (list , tuple )) else [compiled_out ]
123+ )
124+
125+ processed_expected_out = [t .numpy () for t in expected_out_list ]
126+ processed_compiled_out = [t .numpy () for t in compiled_out_list ]
119127
120128 def print_cmp (key , func , ** kwargs ):
121- cmp_ret = func (expected_out , compiled_out , ** kwargs )
129+ cmp_ret = func (processed_expected_out , processed_compiled_out , ** kwargs )
122130 print (
123131 f"{ args .log_prompt } { key } model_path:{ args .model_path } { cmp_ret } " ,
124132 file = sys .stderr ,
You can’t perform that action at this time.
0 commit comments