@@ -152,7 +152,7 @@ def forward(self, x):
 
 class Tester(TestCase):
 
-    def _test_output(self, model, x, kind=None):
+    def _test_output(self, model, x, kind_in_graph=None, kind_not_in_graph=None):
         modelName = model.__class__.__name__
         core.disable_jit_opt()
         core.disable_mix_bf16_fp32()
@@ -164,180 +164,207 @@ def _test_output(self, model, x, kind=None):
 
         script_model = torch.jit.script(model)
         script_model.eval()
+
+        trace_model = torch.jit.trace(model, x)
+        trace_model.eval()
         with torch.no_grad():
             sresult = script_model(x)
+            tresult = trace_model(x)
 
         self.assertEqual(result, sresult)
+        self.assertEqual(result, tresult)
 
         core.enable_jit_opt()
-        fused_model = torch.jit.script(model)
+        script_fused_model = torch.jit.script(model)
+        trace_fused_model = torch.jit.trace(model, x)
         with torch.no_grad():
             # conv relu fusion, conv sum fusion or conv sum relu fusion
-            graph = fused_model.graph_for(x)
-            # print(graph)
-            fresult = fused_model(x)
+            script_graph = script_fused_model.graph_for(x)
+            fused_sresult = script_fused_model(x)
 
-        # print(result)
-        # print(sresult)
-        # print(fresult)
+            trace_graph = trace_fused_model.graph_for(x)
+            fused_tresult = trace_fused_model(x)
 
-        self.assertEqual(result, fresult)
+        self.assertEqual(result, fused_sresult)
+        self.assertEqual(result, fused_tresult)
 
         # check if the fused node exists in the graph
-        if kind is not None:
-            self.assertTrue(any(n.kind() == kind for n in graph.nodes()))
+        if kind_in_graph is not None:
+            self.assertTrue(any(n.kind() == kind_in_graph for n in script_graph.nodes()))
+            self.assertTrue(any(n.kind() == kind_in_graph for n in trace_graph.nodes()))
+
+        # check that a certain node does not exist in the graph
+        if kind_not_in_graph is not None:
+            self.assertTrue(all(n.kind() != kind_not_in_graph for n in script_graph.nodes()))
+            self.assertTrue(all(n.kind() != kind_not_in_graph for n in trace_graph.nodes()))
+
 
-    def _test_output_bf16(self, model, x, kind=None, prec=None):
+    def _test_output_bf16(self, model, x, kind_in_graph=None, kind_not_in_graph=None, prec=None):
         modelName = model.__class__.__name__
 
         core.enable_auto_dnnl()
         core.enable_jit_opt()
-        core.disable_mix_bf16_fp32()
-
+        core.enable_mix_bf16_fp32()
+
         model = model.to(ipex.DEVICE).eval()
         x = x.to(ipex.DEVICE)
         x2 = x.clone()
+        x3 = x.clone()
+
+        script_fused_model = torch.jit.script(copy.deepcopy(model))
+        trace_fused_model = torch.jit.trace(copy.deepcopy(model), x3)
 
-        fused_model = torch.jit.script(copy.deepcopy(model))
-
-        # bn folding, removing it after solve some issue, using mix_preci? to check
-        core.disable_auto_dnnl()
-        fused_model = wrap_cpp_module(torch._C._jit_pass_fold_convbn(fused_model._c))
-        core.enable_auto_dnnl()
-
-        core.enable_mix_bf16_fp32()
-
 
         with torch.no_grad():
             # bf16, native path
             result = model(x)
-            # bf16, jit path
-            graph = fused_model.graph_for(x2)
-            # print(graph)
-            fresult = fused_model(x2)
-
-        #print(result)
-        #print(fresult)
-
-        self.assertEqual(fresult, result, prec=prec)
+            # bf16, jit script path
+            script_graph = script_fused_model.graph_for(x2)
+            fused_sresult = script_fused_model(x2)
+            # bf16, jit trace path
+            trace_graph = trace_fused_model.graph_for(x3)
+            fused_tresult = trace_fused_model(x3)
+
+        # disable mix_bf16_fp32 when the calculation is done,
+        # to avoid affecting other tests
+        core.disable_mix_bf16_fp32()
+
+        self.assertEqual(fused_sresult, result, prec=prec)
+        self.assertEqual(fused_tresult, result, prec=prec)
 
         # check if the fused node exists in the graph
-        if kind is not None:
-            self.assertTrue(any(n.kind() == kind for n in graph.nodes()))
-
+        if kind_in_graph is not None:
+            self.assertTrue(any(n.kind() == kind_in_graph for n in script_graph.nodes()))
+            self.assertTrue(any(n.kind() == kind_in_graph for n in trace_graph.nodes()))
+
+        # check that a certain node does not exist in the graph
+        if kind_not_in_graph is not None:
+            self.assertTrue(all(n.kind() != kind_not_in_graph for n in script_graph.nodes()))
+            self.assertTrue(all(n.kind() != kind_not_in_graph for n in trace_graph.nodes()))
+
 
     def test_output_conv_bn_2d(self):
         self._test_output(
             ConvBatchNorm_Fixed(2, 3, 32, kernel_size=3, stride=1),
-            torch.randn(32, 3, 224, 224),
-            kind="aten::conv2d")
+            torch.randn(32, 3, 64, 64),
+            kind_in_graph="aten::conv2d",
+            kind_not_in_graph="aten::batch_norm")
         self._test_output_bf16(
             ConvBatchNorm_Fixed(2, 3, 32, kernel_size=3, stride=1),
-            torch.randn(32, 3, 224, 224),
-            kind="aten::conv2d",
+            torch.randn(32, 3, 64, 64),
+            kind_in_graph="aten::conv2d",
+            kind_not_in_graph="aten::batch_norm",
             prec=0.02)
 
 
     def test_output_conv_bn_3d(self):
         self._test_output(
             ConvBatchNorm_Fixed(3, 3, 32, kernel_size=3, stride=1),
-            torch.randn(32, 3, 112, 112, 112),
-            kind="aten::conv3d")
+            torch.randn(32, 3, 32, 32, 32),
+            kind_in_graph="aten::conv3d",
+            kind_not_in_graph="aten::batch_norm")
         self._test_output_bf16(
             ConvBatchNorm_Fixed(3, 3, 32, kernel_size=3, stride=1),
-            torch.randn(32, 3, 112, 112, 112),
-            kind="aten::conv3d",
+            torch.randn(32, 3, 32, 32, 32),
+            kind_in_graph="aten::conv3d",
+            kind_not_in_graph="aten::batch_norm",
             prec=0.02)
 
 
     def test_output_conv_relu_2d(self):
         self._test_output(
             ConvRelu_Fixed(2, 3, 32, kernel_size=3, stride=1),
-            torch.randn(32, 3, 224, 224),
-            kind="ipex::conv2d_relu")
+            torch.randn(32, 3, 64, 64),
+            kind_in_graph="ipex::conv2d_relu")
         self._test_output_bf16(
             ConvRelu_Fixed(2, 3, 32, kernel_size=3, stride=1),
-            torch.randn(32, 3, 224, 224),
-            kind="ipex::conv2d_relu")
+            torch.randn(32, 3, 64, 64),
+            kind_in_graph="ipex::conv2d_relu")
 
 
     def test_output_conv_relu_3d(self):
         self._test_output(
             ConvRelu_Fixed(3, 3, 32, kernel_size=3, stride=1),
-            torch.randn(32, 3, 112, 112, 112),
-            kind="ipex::conv3d_relu")
+            torch.randn(32, 3, 32, 32, 32),
+            kind_in_graph="ipex::conv3d_relu")
         self._test_output_bf16(
             ConvRelu_Fixed(3, 3, 32, kernel_size=3, stride=1),
-            torch.randn(32, 3, 112, 112, 112),
-            kind="ipex::conv3d_relu")
+            torch.randn(32, 3, 32, 32, 32),
+            kind_in_graph="ipex::conv3d_relu")
 
 
     def test_output_conv_sum_2d(self):
         self._test_output(
             ConvSum(2, 3, 32, kernel_size=3, stride=1),
-            torch.randn(32, 3, 224, 224),
-            kind="ipex::conv2d_sum")
+            torch.randn(32, 3, 64, 64),
+            kind_in_graph="ipex::conv2d_sum")
         self._test_output_bf16(
             ConvSum(2, 3, 32, kernel_size=3, stride=1),
-            torch.randn(32, 3, 224, 224),
-            kind="ipex::conv2d_sum",
+            torch.randn(32, 3, 64, 64),
+            kind_in_graph="ipex::conv2d_sum",
             prec=0.04)
 
 
     def test_output_conv_sum_3d(self):
         self._test_output(
             ConvSum(3, 3, 32, kernel_size=3, stride=1),
-            torch.randn(32, 3, 112, 112, 112),
-            kind="ipex::conv3d_sum")
+            torch.randn(32, 3, 32, 32, 32),
+            kind_in_graph="ipex::conv3d_sum")
         self._test_output_bf16(
             ConvSum(3, 3, 32, kernel_size=3, stride=1),
-            torch.randn(32, 3, 112, 112, 112),
-            kind="ipex::conv3d_sum",
+            torch.randn(32, 3, 32, 32, 32),
+            kind_in_graph="ipex::conv3d_sum",
             prec=0.04)
 
 
     def test_output_cascaded_conv_bn_sum_relu_2d(self):
         self._test_output(
             CascadedConvBnSumRelu(2, 3, 64, 32, kernel_size=3, stride=1),
-            torch.rand(32, 3, 224, 224),
-            kind="ipex::conv2d_sum_relu")
+            torch.rand(32, 3, 64, 64),
+            kind_in_graph="ipex::conv2d_sum_relu",
+            kind_not_in_graph="aten::batch_norm")
         self._test_output_bf16(
             CascadedConvBnSumRelu(2, 3, 64, 32, kernel_size=3, stride=1),
-            torch.rand(32, 3, 224, 224),
-            kind="ipex::conv2d_sum_relu",
+            torch.rand(32, 3, 64, 64),
+            kind_in_graph="ipex::conv2d_sum_relu",
+            kind_not_in_graph="aten::batch_norm",
             prec=0.02)
 
 
     def test_output_cascaded_conv_bn_sum_relu_3d(self):
         self._test_output(
             CascadedConvBnSumRelu(3, 3, 64, 32, kernel_size=3, stride=1),
-            torch.rand(32, 3, 112, 112, 112),
-            kind="ipex::conv3d_sum_relu")
+            torch.rand(32, 3, 32, 32, 32),
+            kind_in_graph="ipex::conv3d_sum_relu",
+            kind_not_in_graph="aten::batch_norm")
         self._test_output_bf16(
             CascadedConvBnSumRelu(3, 3, 64, 32, kernel_size=3, stride=1),
-            torch.rand(32, 3, 112, 112, 112),
-            kind="ipex::conv3d_sum_relu",
+            torch.rand(32, 3, 32, 32, 32),
+            kind_in_graph="ipex::conv3d_sum_relu",
+            kind_not_in_graph="aten::batch_norm",
             prec=0.02)
 
 
     def test_output_linear_relu(self):
         self._test_output(
             LinearRelu(3, 32, bias=True),
             torch.rand(32, 3),
-            kind="ipex::linear_relu")
+            kind_in_graph="ipex::linear_relu")
         self._test_output_bf16(
             LinearRelu(3, 32, bias=True),
             torch.rand(32, 3),
-            kind="ipex::linear_relu")
+            kind_in_graph="ipex::linear_relu")
         self._test_output(
             LinearRelu(3, 32, bias=False),
             torch.rand(32, 3),
-            kind="ipex::linear_relu")
+            kind_in_graph="ipex::linear_relu")
         self._test_output_bf16(
             LinearRelu(3, 32, bias=False),
             torch.rand(32, 3),
-            kind="ipex::linear_relu")
+            kind_in_graph="ipex::linear_relu")
 
 
 if __name__ == '__main__':
+    torch.manual_seed(2020)
     core.enable_auto_dnnl()
     test = unittest.main()
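
For anyone who wants to poke at this pattern outside the suite: the helpers above boil down to compiling the model twice (torch.jit.script and torch.jit.trace), asking the JIT for the graph it will actually execute via graph_for(x), and scanning node kinds. Below is a minimal, self-contained sketch of that pattern in plain PyTorch, without ipex; TinyConvRelu is a hypothetical stand-in for the fixtures in this file, and the exact node kinds that show up depend on the PyTorch version and which fusion passes ran, so the sketch prints the checks instead of asserting them.

import torch
import torch.nn as nn

class TinyConvRelu(nn.Module):
    """Hypothetical stand-in for the ConvRelu_Fixed fixture."""
    def __init__(self):
        super(TinyConvRelu, self).__init__()
        self.conv = nn.Conv2d(3, 8, kernel_size=3, stride=1)

    def forward(self, x):
        return torch.relu(self.conv(x))

model = TinyConvRelu().eval()
x = torch.randn(1, 3, 16, 16)

script_model = torch.jit.script(model)
trace_model = torch.jit.trace(model, x)

with torch.no_grad():
    # graph_for(x) returns the graph the JIT runs for this input,
    # after its optimization passes (including any fusion) have run
    script_graph = script_model.graph_for(x)
    trace_graph = trace_model.graph_for(x)

# inspect what actually ended up in the optimized graph
print([n.kind() for n in script_graph.nodes()])

# the same scans _test_output performs: kind_in_graph must appear ...
print(any(n.kind() == "aten::relu" for n in trace_graph.nodes()))
# ... and kind_not_in_graph must not
print(all(n.kind() != "aten::batch_norm" for n in trace_graph.nodes()))

Checking script and trace side by side, as the PR now does, matters because the two frontends build different initial graphs, so a fusion or folding pass can fire on one and miss the other.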