@@ -73,12 +73,26 @@ def get_input_spec(args):
7373 inputs_params_list = utils .load_converted_list_from_text (f"{ args .model_path } " )
7474 input_spec = [None ] * len (inputs_params_list )
7575 for i , v in enumerate (inputs_params_list ):
76+ name = v ["name" ]
7677 dtype = v ["info" ]["dtype" ]
7778 shape = v ["info" ]["shape" ]
79+ # print(f"-- i: {i}, v: name={name}, shape={shape}, dtype={dtype}")
7880 input_spec [i ] = paddle .static .InputSpec (shape , dtype )
7981 return input_spec
8082
8183
84+ def regular_item (item ):
85+ if isinstance (item , paddle .Tensor ) and (
86+ item .dtype == paddle .bfloat16 or item .dtype == paddle .bfloat32
87+ ):
88+ item = np .array (item .astype ("float32" ))
89+ else :
90+ item = np .array (item )
91+ if item .dtype == np .bool_ :
92+ item = item .astype ("float32" )
93+ return item
94+
95+
8296def test_single_model (args ):
8397 synchronizer_func = get_synchronizer_func (args )
8498 input_dict = get_input_dict (args )
@@ -88,20 +102,18 @@ def test_single_model(args):
88102 build_strategy .build_cinn_pass = False
89103
90104 # eager
91- model = paddle .jit .to_static (
92- model_dy ,
93- full_graph = False ,
94- )
95- model .eval ()
105+ print ("-- Run with eager mode" )
106+ model_dy .eval ()
96107 for _ in range (args .warmup if args .warmup > 0 else 0 ):
97- model (** input_dict )
108+ model_dy (** input_dict )
98109 eager_duration_box = DurationBox (- 1 )
99110 with naive_timer (eager_duration_box , synchronizer_func ):
100- expected_out = model (** input_dict )
111+ expected_out = model_dy (** input_dict )
101112
102113 # compiled
114+ print ("-- Run with compiled mode" )
103115 build_strategy = paddle .static .BuildStrategy ()
104- build_strategy .build_cinn_pass = True
116+ # build_strategy.build_cinn_pass = True
105117 compiled_model = paddle .jit .to_static (
106118 model_dy ,
107119 input_spec = input_spec ,
@@ -115,11 +127,15 @@ def test_single_model(args):
115127 with naive_timer (compiled_duration_box , synchronizer_func ):
116128 compiled_out = compiled_model (** input_dict )
117129 if isinstance (expected_out , paddle .Tensor ):
118- expected_out = expected_out .numpy ()
119- compiled_out = compiled_out .numpy ()
120- elif isinstance (expected_out , list ) or isinstance (expected_out , tuple ):
121- expected_out = expected_out [0 ].numpy ()
122- compiled_out = compiled_out [0 ].numpy ()
130+ expected_out = [expected_out ]
131+ compiled_out = [compiled_out ]
132+ if isinstance (expected_out , list ) or isinstance (expected_out , tuple ):
133+ expected_out = [
134+ regular_item (item ) for item in expected_out if np .array (item ).size != 0
135+ ]
136+ compiled_out = [
137+ regular_item (item ) for item in compiled_out if np .array (item ).size != 0
138+ ]
123139 else :
124140 raise ValueError ("Illegal return value." )
125141
0 commit comments