@@ -486,6 +486,90 @@ def test_batch_norm3d_backward(self):
486486 self .assertTrue (ipex .core .is_bf16_dil_tensor (x_auto_mix_bf16 .grad ))
487487 self .assertEqual (x_man_bf16 .grad .float (), x_auto_mix_bf16 .grad )
488488
class TestLayerNorm(TestCase):
    """Forward/backward checks of ``torch.nn.LayerNorm([10, 10])`` with auto mixed precision.

    Each test builds five tensors and five op copies from one random seed via
    ``_gen_tensor`` / ``_gen_op`` and compares the auto-mix results against
    either the ``op_cpu`` result or the manually-converted bf16 result.
    """

    # NOTE(review): the AutoMixPrecision(True, ...) sections are written nested
    # inside the AutoDNNL(True) block in both tests -- confirm against the
    # sibling batch-norm tests in this file.

    def test_layer_norm(self):
        """Check LayerNorm forward results and dtypes for inference and training."""
        seed = int(get_rand_seed())
        test_name = sys._getframe().f_code.co_name
        print("{} rand sed: {}".format(test_name, seed))

        tensors = _gen_tensor(seed, (2, 5, 10, 10))
        x_cpu, x_infer, x_train, x_man_bf16, x_train_bf16 = tensors

        ops = _gen_op(seed, torch.nn.LayerNorm([10, 10]), is_bn=True)
        op_cpu, op_infer, op_train, op_man_bf16, op_train_bf16 = ops

        is_bf16_dil = ipex.core.is_bf16_dil_tensor

        expected_fp32 = op_cpu(x_cpu)
        with AutoDNNL(True):
            with AutoMixPrecision(False):
                expected_bf16 = op_man_bf16(x_man_bf16)
                self.assertEqual(expected_bf16.dtype, torch.bfloat16)

                # Forward, inference mode: input starts as a plain float
                # tensor; afterwards both output and input are asserted to be
                # bf16 dil tensors while still reporting torch.float.
                with AutoMixPrecision(True, train=False):
                    self.assertEqual(x_infer.dtype, torch.float)
                    self.assertFalse(is_bf16_dil(x_infer))
                    out_infer = op_infer(x_infer)
                    for tensor in (out_infer, x_infer):
                        self.assertEqual(tensor.dtype, torch.float)
                    for tensor in (out_infer, x_infer):
                        self.assertTrue(is_bf16_dil(tensor))
                    self.assertEqual(expected_bf16.float(), out_infer)

                # Forward, training mode, input is not a bf16 dil tensor:
                # neither output nor input is expected to become one.
                with AutoMixPrecision(True, train=True):
                    self.assertEqual(x_train.dtype, torch.float)
                    self.assertFalse(is_bf16_dil(x_train))
                    out_train = op_train(x_train)
                    for tensor in (out_train, x_train):
                        self.assertEqual(tensor.dtype, torch.float)
                    for tensor in (out_train, x_train):
                        self.assertFalse(is_bf16_dil(tensor))
                    self.assertEqual(expected_fp32, out_train)

                # Forward, training mode, input already is a bf16 dil tensor:
                # the output is expected to be one as well.
                with AutoMixPrecision(True, train=True):
                    self.assertEqual(x_train_bf16.dtype, torch.float)
                    self.assertTrue(is_bf16_dil(x_train_bf16))
                    out_train_bf16 = op_train_bf16(x_train_bf16)
                    for tensor in (out_train_bf16, x_train_bf16):
                        self.assertEqual(tensor.dtype, torch.float)
                    for tensor in (out_train_bf16, x_train_bf16):
                        self.assertTrue(is_bf16_dil(tensor))
                    self.assertEqual(expected_bf16.float(), out_train_bf16)

    def test_layer_norm_backward(self):
        """Check LayerNorm input gradients and their dtypes in training mode."""
        seed = int(get_rand_seed())
        test_name = sys._getframe().f_code.co_name
        print("{} rand sed: {}".format(test_name, seed))

        tensors = _gen_tensor(seed, (2, 5, 10, 10), is_forward=False)
        x_cpu, _, x_mix, x_man_bf16, x_mix_bf16 = tensors

        ops = _gen_op(seed, torch.nn.LayerNorm([10, 10]), is_bn=True, is_forward=False)
        op_cpu, _, op_mix, op_man_bf16, op_mix_bf16 = ops

        is_bf16_dil = ipex.core.is_bf16_dil_tensor

        # Reference gradient from the op_cpu / x_cpu pair.
        op_cpu(x_cpu).sum().backward()

        with AutoDNNL(True):
            with AutoMixPrecision(False, train=True):
                op_man_bf16(x_man_bf16).sum().backward()
                self.assertEqual(x_man_bf16.grad.dtype, torch.bfloat16)
                # Reference gradient is rounded through bf16 and compared with
                # a 1e-2 tolerance (third positional argument of assertEqual).
                self.assertEqual(x_cpu.grad.bfloat16().float(), x_man_bf16.grad, 1e-2)

                # Backward, input is not a bf16 dil tensor: gradient is
                # expected to match the reference gradient.
                with AutoMixPrecision(True, train=True):
                    self.assertEqual(x_mix.dtype, torch.float)
                    self.assertFalse(is_bf16_dil(x_mix))
                    op_mix(x_mix).sum().backward()
                    self.assertEqual(x_mix.grad.dtype, torch.float)
                    self.assertFalse(is_bf16_dil(x_mix.grad))
                    self.assertEqual(x_cpu.grad, x_mix.grad)

                # Backward, input is a bf16 dil tensor: gradient is expected
                # to be a bf16 dil tensor equal to the manual bf16 gradient.
                with AutoMixPrecision(True, train=True):
                    self.assertEqual(x_mix_bf16.dtype, torch.float)
                    self.assertTrue(is_bf16_dil(x_mix_bf16))
                    op_mix_bf16(x_mix_bf16).sum().backward()
                    self.assertEqual(x_mix_bf16.grad.dtype, torch.float)
                    self.assertTrue(is_bf16_dil(x_mix_bf16.grad))
                    self.assertEqual(x_man_bf16.grad.float(), x_mix_bf16.grad)
572+
489573class TestRelu (TestCase ):
490574 def test_relu (self ):
491575 rand_seed = int (get_rand_seed ())
0 commit comments