@@ -381,7 +381,7 @@ def native_layer_norm(input: Tensor, normalized_shape: List[int], weight: Option
     if M > 0:
         input_reshaped = input.view(1, M, -1)
     else:
-        return (input, aten.new_empty(input, (0,)), aten.new_empty(input, (0,)))
+        return (input, aten.new_zeros(input, (0,)), aten.new_zeros(input, (0,)))
 
     # Unlike Batch Normalization, which applies scalar scale and bias for each
     # entire channel/plane with the affine option, Layer Normalization applies
@@ -439,6 +439,72 @@ def addmm(self: Tensor, mat1: Tensor, mat2: Tensor, beta: int = 1, alpha: int =
         return out
     return beta * self + out
 
+@register_decomposition(aten.native_layer_norm_backward)
+def native_layer_norm_backward(grad_out: Tensor, input: Tensor, normalized_shape: List[int], mean: Tensor, rstd: Tensor, weight: Optional[Tensor], bias: Optional[Tensor], output_mask: List[bool]) -> Tuple[Tensor, Tensor, Tensor]:
+    input_shape = input.shape
+    input_ndim = input.dim()
+
+    axis = input_ndim - len(normalized_shape)
+    inner_dims = input_shape[axis:]
+    outer_dims = input_shape[:axis]
+    inner_dim_indices = []
+    outer_dim_indices = []
+    for i in range(input_ndim):
+        if (i >= axis):
+            inner_dim_indices.append(i)
+        else:
+            outer_dim_indices.append(i)
+
+    N = float(prod(inner_dims))
+    M = prod(outer_dims)
+    if M <= 0 or N <= 0.0:
+        return (aten.new_empty(input, input_shape), aten.new_zeros(input[axis:], input_shape[axis:]), aten.new_zeros(input[axis:], input_shape[axis:]))
+
+    x_hat = aten.mul(aten.sub(input, mean), rstd)
+    if weight is not None:
+        grad_x_hat = aten.mul(grad_out, weight)
+    else:
+        grad_x_hat = grad_out
+    a = aten.mul(grad_x_hat, N)
+    b = aten.sum(grad_x_hat, inner_dim_indices, True)
+    c1 = aten.mul(grad_x_hat, x_hat)
+    c2 = aten.sum(c1, inner_dim_indices, True)
+    c3 = aten.mul(x_hat, c2)
+
+    inner = aten.sub(aten.sub(a, b), c3)
+
+    if output_mask[0]:
+        d_input = aten.mul(aten.div(rstd, N), inner)
+    else:
+        d_input = None
+
+    if output_mask[1] and weight is not None:
+        if len(outer_dim_indices) > 0:
+            d_weight = aten.sum(aten.mul(grad_out, x_hat), outer_dim_indices, False)
+        else:
+            d_weight = aten.mul(grad_out, x_hat)
+    else:
+        d_weight = None
+
+    if output_mask[2] and bias is not None:
+        if len(outer_dim_indices) > 0:
+            d_bias = aten.sum(grad_out, outer_dim_indices, False)
+        else:
+            d_bias = grad_out
+    else:
+        d_bias = None
+    return (d_input, d_weight, d_bias)
+
+# @register_decomposition(aten.addmm)
+# def addmm(self: Tensor, mat1: Tensor, mat2: Tensor, beta=1, alpha=1):
+#     if not self.is_floating_point():
+#         beta = int(beta)
+#         alpha = int(alpha)
+#     out = alpha * aten.mm(mat1, mat2)
+#     if beta == 0:
+#         return out
+#     return beta * self + out
+
 
 @register_decomposition(aten.clamp_min)
 def clamp_min(self: Tensor, min: float):
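
For reference (not part of the commit): the d_input / d_weight / d_bias algebra added above can be sanity-checked against autograd with plain torch ops. The sketch below is only an illustration of the formula, not the register_decomposition machinery; the helper name layer_norm_backward_ref, the shapes, and the explicit eps=1e-5 (to match F.layer_norm's default) are assumptions introduced here.

import math
import torch
import torch.nn.functional as F

def layer_norm_backward_ref(grad_out, x, normalized_shape, mean, rstd, weight):
    # Same algebra as the decomposition above, with g = grad_out * weight:
    #   d_input = rstd / N * (N * g - sum(g) - x_hat * sum(g * x_hat))
    axis = x.dim() - len(normalized_shape)
    inner = list(range(axis, x.dim()))      # dims that were normalized over
    outer = list(range(axis))               # leading (batch-like) dims
    N = float(math.prod(x.shape[axis:]))
    x_hat = (x - mean) * rstd
    g = grad_out * weight if weight is not None else grad_out
    d_input = (rstd / N) * (N * g
                            - g.sum(inner, keepdim=True)
                            - x_hat * (g * x_hat).sum(inner, keepdim=True))
    d_weight = (grad_out * x_hat).sum(outer) if weight is not None else None
    d_bias = grad_out.sum(outer)
    return d_input, d_weight, d_bias

# Compare against autograd on a small example (hypothetical shapes).
torch.manual_seed(0)
x = torch.randn(4, 5, 6, dtype=torch.double, requires_grad=True)
w = torch.randn(6, dtype=torch.double, requires_grad=True)
b = torch.randn(6, dtype=torch.double, requires_grad=True)
out = F.layer_norm(x, (6,), w, b)          # eps defaults to 1e-5
grad_out = torch.randn_like(out)
dx_ref, dw_ref, db_ref = torch.autograd.grad(out, (x, w, b), grad_out)

# Recompute mean/rstd the way the forward does, then apply the formula.
mean = x.detach().mean(-1, keepdim=True)
var = x.detach().var(-1, unbiased=False, keepdim=True)
rstd = (var + 1e-5).rsqrt()
dx, dw, db = layer_norm_backward_ref(grad_out, x.detach(), (6,), mean, rstd, w.detach())
assert torch.allclose(dx, dx_ref) and torch.allclose(dw, dw_ref) and torch.allclose(db, db_ref)

The reduction dims mirror inner_dim_indices / outer_dim_indices in the decomposition: sums over the normalized dims keep their dimensions for broadcasting into d_input, while sums over the leading dims collapse grad_out and grad_out * x_hat down to the weight/bias shape.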