File tree Expand file tree Collapse file tree 2 files changed +12
-4
lines changed
Expand file tree Collapse file tree 2 files changed +12
-4
lines changed Original file line number Diff line number Diff line change @@ -124,15 +124,23 @@ def test_single_model(args):
124124 compiled_duration_box = DurationBox (- 1 )
125125 with naive_timer (compiled_duration_box , synchronizer_func ):
126126 compiled_out = compiled_model (** input_dict )
127+
127128 if isinstance (expected_out , paddle .Tensor ):
128129 expected_out = [expected_out ]
129130 compiled_out = [compiled_out ]
130131 if isinstance (expected_out , list ) or isinstance (expected_out , tuple ):
132+ for a , b in zip (expected_out , compiled_out ):
133+ if (a is None and b is not None ) or (a is not None and b is None ):
134+ raise ValueError ("Both expected_out and compiled_out must be not None." )
131135 expected_out = [
132- regular_item (item ) for item in expected_out if np .array (item ).size != 0
136+ regular_item (item )
137+ for item in expected_out
138+ if item is not None and np .array (item ).size != 0
133139 ]
134140 compiled_out = [
135- regular_item (item ) for item in compiled_out if np .array (item ).size != 0
141+ regular_item (item )
142+ for item in compiled_out
143+ if item is not None and np .array (item ).size != 0
136144 ]
137145 else :
138146 raise ValueError ("Illegal return value." )
Original file line number Diff line number Diff line change @@ -66,9 +66,9 @@ def main(args):
6666 params .update (inputs )
6767 state_dict = {k : utils .replay_tensor (v ) for k , v in params .items ()}
6868
69- y = model (** state_dict )[ 0 ]
69+ y = model (** state_dict )
7070
71- print (np .argmin (y ), np .argmax (y ))
71+ # print(np.argmin(y), np.argmax(y))
7272 if isinstance (y , paddle .Tensor ):
7373 print (y .shape )
7474 elif isinstance (y , list ) or isinstance (y , tuple ):
You can’t perform that action at this time.
0 commit comments