@@ -439,6 +439,7 @@ def addmm(self: Tensor, mat1: Tensor, mat2: Tensor, beta: int = 1, alpha: int =
         return out
     return beta * self + out

+
 @register_decomposition(aten.native_layer_norm_backward)
 def native_layer_norm_backward(grad_out: Tensor, input: Tensor, normalized_shape: List[int], mean: Tensor, rstd: Tensor, weight: Optional[Tensor], bias: Optional[Tensor], output_mask: List[bool]) -> Tuple[Tensor, Tensor, Tensor]:
     input_shape = input.shape
@@ -447,18 +448,18 @@ def native_layer_norm_backward(grad_out: Tensor, input: Tensor, normalized_shape
     axis = input_ndim - len(normalized_shape)
     inner_dims = input_shape[axis:]
     outer_dims = input_shape[:axis]
-    inner_dim_indices = []
-    outer_dim_indices = []
+    inner_dim_indices: List[int] = []
+    outer_dim_indices: List[int] = []
     for i in range(input_ndim):
         if (i >= axis):
             inner_dim_indices.append(i)
         else:
             outer_dim_indices.append(i)

-    N = float(prod(inner_dims))
+    N = prod(inner_dims)
     M = prod(outer_dims)
-    if M <= 0 or N <= 0.0:
-        return (aten.new_empty(input, input_shape), aten.new_zeros(input[axis:], input_shape[axis:]), aten.new_zeros(input[axis:], input_shape[axis:]))
+    if M <= 0 or N <= 0:
+        return (aten.new_zeros(input, input_shape), aten.new_zeros(input, input_shape[axis:]), aten.new_zeros(input, input_shape[axis:]))

     x_hat = aten.mul(aten.sub(input, mean), rstd)
     if weight is not None:
@@ -476,23 +477,23 @@ def native_layer_norm_backward(grad_out: Tensor, input: Tensor, normalized_shape
     if output_mask[0]:
         d_input = aten.mul(aten.div(rstd, N), inner)
     else:
-        d_input = None
+        d_input = aten.new_empty(input, (0,))

     if output_mask[1] and weight is not None:
         if len(outer_dim_indices) > 0:
             d_weight = aten.sum(aten.mul(grad_out, x_hat), outer_dim_indices, False)
         else:
             d_weight = aten.mul(grad_out, x_hat)
     else:
-        d_weight = None
+        d_weight = aten.new_empty(input, (0,))

     if output_mask[2] and bias is not None:
         if len(outer_dim_indices) > 0:
             d_bias = aten.sum(grad_out, outer_dim_indices, False)
         else:
             d_bias = grad_out
     else:
-        d_bias = None
+        d_bias = aten.new_empty(input, (0,))
     return (d_input, d_weight, d_bias)

 # @register_decomposition(aten.addmm)
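
For anyone reading the hunks above out of context: the patch swaps the `None` placeholders for empty tensors, keeps `N` as an integer element count, and makes the degenerate-shape early return use `new_zeros` on `input` itself. Below is a small sanity check, not part of the patch, that compares a hand-written layer-norm input gradient against autograd. It uses the textbook backward formula; the patch's own `inner` term is defined in lines elided from this diff, so treat the math here as an assumption about the decomposition rather than a copy of it.

# Illustrative cross-check only -- not part of this patch. Uses the standard
# layer-norm backward formula; the patch's "inner" term is not shown above.
import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, S, H = 2, 3, 4                        # outer dims (B, S), inner dim (H,)
normalized_shape = (H,)
eps = 1e-5

x = torch.randn(B, S, H, dtype=torch.double, requires_grad=True)
w = torch.randn(H, dtype=torch.double, requires_grad=True)
b = torch.randn(H, dtype=torch.double, requires_grad=True)
grad_out = torch.randn(B, S, H, dtype=torch.double)

# Reference input gradient from autograd.
y = F.layer_norm(x, normalized_shape, w, b, eps)
d_input_ref, = torch.autograd.grad(y, x, grad_out)

# Hand-derived gradient, mirroring the decomposition's structure:
# x_hat = (x - mean) * rstd, with reductions taken over the inner dims.
N = H                                    # prod(inner_dims), kept as an int
mean = x.mean(dim=-1, keepdim=True)
rstd = (x.var(dim=-1, unbiased=False, keepdim=True) + eps).rsqrt()
x_hat = (x - mean) * rstd
g = grad_out * w                         # gradient w.r.t. the normalized input
inner = N * g - g.sum(dim=-1, keepdim=True) - x_hat * (g * x_hat).sum(dim=-1, keepdim=True)
d_input = (rstd / N) * inner             # matches the d_input line in the diff

print(torch.allclose(d_input, d_input_ref))  # expected: True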