@@ -81,6 +81,18 @@ def get_input_spec(args):
8181 return input_spec
8282
8383
84+ def regular_item (item ):
85+ if isinstance (item , paddle .Tensor ) and (
86+ item .dtype == paddle .bfloat16 or item .dtype == paddle .bfloat32
87+ ):
88+ item = np .array (item .astype ("float32" ))
89+ else :
90+ item = np .array (item )
91+ if item .dtype == np .bool_ :
92+ item = item .astype ("float32" )
93+ return item
94+
95+
8496def test_single_model (args ):
8597 synchronizer_func = get_synchronizer_func (args )
8698 input_dict = get_input_dict (args )
@@ -115,24 +127,15 @@ def test_single_model(args):
115127 with naive_timer (compiled_duration_box , synchronizer_func ):
116128 compiled_out = compiled_model (** input_dict )
117129 if isinstance (expected_out , paddle .Tensor ):
118- expected_out = [expected_out .numpy ().astype ("float32" )]
119- compiled_out = [compiled_out .numpy ().astype ("float32" )]
120- elif isinstance (expected_out , list ) or isinstance (expected_out , tuple ):
121- if isinstance (expected_out , tuple ):
122- expected_out = list (expected_out )
123- compiled_out = list (compiled_out )
124- new_expected = [
125- np .array (item ).astype ("float32" )
126- for item in expected_out
127- if np .array (item ).size != 0
130+ expected_out = [expected_out ]
131+ compiled_out = [compiled_out ]
132+ if isinstance (expected_out , list ) or isinstance (expected_out , tuple ):
133+ expected_out = [
134+ regular_item (item ) for item in expected_out if np .array (item ).size != 0
128135 ]
129- new_compiled = [
130- np .array (item ).astype ("float32" )
131- for item in compiled_out
132- if np .array (item ).size != 0
136+ compiled_out = [
137+ regular_item (item ) for item in compiled_out if np .array (item ).size != 0
133138 ]
134- expected_out = new_expected
135- compiled_out = new_compiled
136139 else :
137140 raise ValueError ("Illegal return value." )
138141
0 commit comments