@@ -33,6 +33,13 @@ def get_rand_seed():
3333 return int (time .time () * 1000000000 )
3434
3535device = ipex .DEVICE
36+
def convert_blocked(t):
    """Copy a 4-D tensor to ``device`` and pass it through an identity depthwise conv.

    The convolution uses one 1x1 all-ones filter per channel with
    ``groups`` equal to the channel count and no bias, so every output
    element equals the corresponding input element.

    NOTE(review): presumably the point is that the conv output is left in the
    backend's blocked memory layout (hence the name) -- confirm against the
    ipex backend.

    Args:
        t: a 4-dimensional tensor; dimension 1 is used as the channel count.

    Returns:
        The result of ``F.conv2d`` applied to a clone of ``t`` moved to
        ``device``.

    Raises:
        AssertionError: if ``t`` is not 4-dimensional.
    """
    assert t.dim() == 4, "only support converting 4d tensor"
    channels = t.size(1)
    # Clone first so the caller's tensor is never aliased by the device copy.
    on_device = t.clone().to(device)
    # One 1x1 filter of ones per channel: a per-channel multiply by 1.
    identity_weight = torch.ones(channels, 1, 1, 1).to(device)
    return F.conv2d(on_device, identity_weight, groups=channels)
42+
3643class TestConv (TestCase ):
3744 def test_Conv2d_with_cpu (self ):
3845 rand_seed = int (get_rand_seed ())
@@ -202,6 +209,78 @@ def test_mul_(self):
202209 a2 = self ._test_mul_ ('cpu' , rand_seed )
203210 self .assertEqual (a2 , a1 .to ('cpu' ))
204211
def test_mixed_format(self):
    """Check add/mul across plain and blocked device tensors against CPU results.

    For each of ``torch.add`` and ``torch.mul`` the following forms are
    exercised with four operand combinations (plain/plain, plain/blocked,
    blocked/plain, blocked/blocked):

    * the out-of-place function,
    * the ``out=`` variant, once with an ``out`` of the result shape and
      once with a 1-element ``out``,
    * the in-place method (``add_`` / ``mul_``).

    Finally broadcasting of a plain device tensor against a 1-element
    tensor is compared with the same computation done entirely on CPU.
    """
    ipex.core.enable_auto_dnnl()
    rand_seed = int(get_rand_seed())
    print("{} rand sed: {}".format(sys._getframe().f_code.co_name, rand_seed))
    torch.manual_seed(rand_seed)

    shape = (2, 3, 4, 5)

    for fname in ['add', 'mul']:
        x_cpu = torch.ones(shape) * 5
        y_cpu = torch.ones(shape) * 4

        # block tensor is a dpcpp tensor
        x_plain = x_cpu.clone().to(device)
        y_plain = y_cpu.clone().to(device)
        x_block = convert_blocked(x_cpu.clone())
        y_block = convert_blocked(y_cpu.clone())

        fn = getattr(torch, fname)
        ref = fn(x_cpu, y_cpu)

        # Every plain/blocked combination of the two operands. The third
        # pair swaps x and y; add and mul are commutative, so ``ref`` is
        # still the expected result.
        operand_pairs = (
            (x_plain, y_plain),
            (x_plain, y_block),
            (y_block, x_plain),
            (x_block, y_block),
        )

        # test add, mul
        def check_outplace(a, b):
            a = a.clone()
            b = b.clone()
            self.assertEqual(fn(a, b), ref)

        for a, b in operand_pairs:
            check_outplace(a, b)

        # test add_out, mul_out
        def check_out(a, b, o):
            a = a.clone()
            b = b.clone()
            # Clone so each call starts from a fresh ``out`` tensor.
            o = o.clone()
            y = fn(a, b, out=o)
            self.assertEqual(y, ref)
            self.assertEqual(o, ref)

        # First an ``out`` that already has the result shape, then a
        # 1-element ``out`` whose shape differs from the result.
        for out in (torch.ones(shape).to(device), torch.ones(1).to(device)):
            for a, b in operand_pairs:
                check_out(a, b, out)

        # test add_, mul_
        def check_inplace(a, b):
            a = a.clone()
            b = b.clone()
            y = getattr(a, fname + '_')(b)
            self.assertEqual(a, ref)
            self.assertEqual(y, ref)

        for a, b in operand_pairs:
            check_inplace(a, b)

        # test broadcast
        # The expected values are computed purely on CPU so the reference
        # does not itself mix a CPU tensor with a device tensor.
        scalar_cpu = torch.ones(1)
        scalar = scalar_cpu.to(device)
        self.assertEqual(fn(x_plain, scalar), fn(x_cpu, scalar_cpu))
        self.assertEqual(fn(scalar, x_plain), fn(scalar_cpu, x_cpu))
282+
283+
205284class TestRelu (TestCase ):
206285 def _test_relu_ (self , device , rand_seed ):
207286 torch .manual_seed (rand_seed )
@@ -388,6 +467,11 @@ def test_addmm(self):
388467 torch .addmm (input = res_dpcpp , mat1 = b1_dpcpp , mat2 = b2_dpcpp , alpha = alpha , beta = beta , out = y_dpcpp )
389468 self .assertEqual (y_cpu , y_dpcpp )
390469
470+ res_cpu .addmm_ (mat1 = b1_cpu , mat2 = b2_cpu , alpha = alpha , beta = beta )
471+ res_dpcpp .addmm_ (mat1 = b1_cpu , mat2 = b2_cpu , alpha = alpha , beta = beta )
472+ self .assertEqual (res_cpu , res_dpcpp )
473+
474+
391475 def test_addbmm (self ):
392476 ipex .core .enable_auto_dnnl ()
393477 rand_seed = int (get_rand_seed ())
@@ -415,6 +499,10 @@ def test_addbmm(self):
415499 torch .addbmm (res_dpcpp , b1_dpcpp , b2_dpcpp , beta = beta , alpha = alpha , out = y_dpcpp )
416500 self .assertEqual (y_cpu , y_dpcpp , 1e-4 )
417501
502+ res_cpu .addbmm_ (b1_cpu , b2_cpu , beta = beta , alpha = alpha )
503+ res_dpcpp .addbmm_ (b1_dpcpp , b2_dpcpp , beta = beta , alpha = alpha )
504+ self .assertEqual (res_cpu , res_dpcpp , 1e-4 )
505+
418506 def test_baddbmm (self ):
419507 ipex .core .enable_auto_dnnl ()
420508 rand_seed = int (get_rand_seed ())
@@ -441,6 +529,9 @@ def test_baddbmm(self):
441529 torch .baddbmm (res_cpu , b1_cpu , b2_cpu , alpha = alpha , beta = beta , out = y_cpu ),
442530 torch .baddbmm (res_dpcpp , b1_dpcpp , b2_dpcpp , alpha = alpha , beta = beta , out = y_dpcpp ),
443531 self .assertEqual (y_cpu , y_dpcpp )
532+ res_cpu .baddbmm_ (b1_cpu , b2_cpu , alpha = alpha , beta = beta )
533+ res_dpcpp .baddbmm_ (b1_cpu , b2_cpu , alpha = alpha , beta = beta )
534+ self .assertEqual (res_cpu , res_dpcpp )
444535
445536class TestLinear (TestCase ):
446537 def test_linear (self ):
0 commit comments