@@ -47,6 +47,17 @@ def log_sigmoid_forward(self: Tensor) -> Tuple[Tensor, Tensor]:
     return min - torch.log1p(z), buffer
 
 
+def recompute_mean_var(input: Tensor, rstd: Tensor, inner_dim_indices: List[int], keepdim: bool):
+    # for most norm decompositions, it will be the same as the core version except for here.
+    # We recompute the mean and variance so that they track gradients through input
+
+    mean = torch.mean(input, dim=inner_dim_indices, keepdim=keepdim)
+    var = torch.var(input, dim=inner_dim_indices, unbiased=False, keepdim=keepdim)
+    eps = torch.pow(1 / rstd, 2) - var  # this makes me so sad inside
+    eps = eps.detach()
+    rstd = 1 / torch.sqrt(var + eps)
+    return mean, rstd
+
 @register_decomposition_for_jvp(aten.native_layer_norm_backward)
 def native_layer_norm_backward(
     grad_out: Tensor,
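(Aside, not part of the diff: the forward pass saves rstd = 1 / sqrt(var + eps) but not eps itself, which is why the helper backs eps out as (1 / rstd) ** 2 - var and detaches it before rebuilding a differentiable rstd. Below is a minimal sketch of that behavior, assuming recompute_mean_var is in scope; the shapes and tolerances are made up for illustration.)

import torch

x = torch.randn(4, 8, requires_grad=True)
dims = [1]
# what the forward pass would have stashed (no graph attached)
saved_var = x.var(dim=dims, unbiased=False, keepdim=True).detach()
saved_rstd = 1 / torch.sqrt(saved_var + 1e-5)

mean, rstd = recompute_mean_var(x, saved_rstd, dims, keepdim=True)
print(torch.allclose(rstd, saved_rstd, atol=1e-5))   # True: numerically the same rstd
print(rstd.requires_grad, saved_rstd.requires_grad)  # True False: the rebuilt rstd tracks x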
@@ -80,13 +91,7 @@ def native_layer_norm_backward(
             input.new_zeros(input_shape[axis:]),
         )
 
-    # this is exactly the same as the other decomposition except for here. We recompute the mean and variance
-    # so that they track gradients through input
-    mean_ = torch.mean(input, dim=inner_dim_indices, keepdim=True)
-    var = torch.var(input, dim=inner_dim_indices, unbiased=False, keepdim=True)
-    eps = torch.pow(1 / rstd, 2) - var  # this makes me so sad inside
-    eps = eps.detach()
-    rstd_ = 1 / torch.sqrt(var + eps)
+    mean_, rstd_ = recompute_mean_var(input, rstd, inner_dim_indices, keepdim=True)
 
     x_hat = (input - mean_) * rstd_
     if weight is not None:
@@ -128,3 +133,84 @@ def native_layer_norm_backward(
         d_bias = torch.zeros(())  # should be None but doesn't work with vjp
 
     return (d_input, d_weight, d_bias)
+
+
+def prod(x: List[int]):
+    r = 1
+    for i in x:
+        r *= i
+    return r
+
+
+@register_decomposition(aten.native_batch_norm_backward)  # @register_decomposition_for_jvp after in core
+def native_batch_norm_backward(
+    grad_out: Tensor,
+    input: Tensor,
+    weight: Optional[Tensor],
+    running_mean: Optional[Tensor],
+    running_var: Optional[Tensor],
+    save_mean: Optional[Tensor],
+    save_invstd: Optional[Tensor],
+    train: bool,
+    eps: float,
+    output_mask: List[bool],
+) -> Tuple[Tensor, Optional[Tensor], Optional[Tensor]]:
+    input_shape = input.shape
+    input_rank = input.dim()
+    assert input_rank >= 2, "rank of the input must be at least 2"
+
+    axis = 1
+    num_features = prod(input_shape) / input_shape[axis]
+    mean = save_mean
+    invstd = save_invstd
+    if train:
+        assert save_mean is not None and save_invstd is not None, "when train=True, save_mean and save_invstd are required"
+
+        reduction_dims = [0] + list(range(2, input.dim()))
+        assert invstd is not None  # for typing
+        mean, invstd = recompute_mean_var(input, invstd, reduction_dims, keepdim=False)
+    else:
+        assert running_mean is not None and running_var is not None
+        mean = running_mean
+        invstd = torch.rsqrt(running_var + eps)
+
+    broadcast_mask = [1] * input_rank
+    broadcast_mask[axis] = input_shape[axis]
+
+    reduction_axes: List[int] = []
+    for i in range(input_rank):
+        if i != axis:
+            reduction_axes.append(i)
+
+    mean = torch.reshape(mean, broadcast_mask)
+    norm = 1.0 / num_features
+    grad_output_sum = torch.sum(grad_out, reduction_axes)
+    dot_p = torch.sum(grad_out * (input - mean), reduction_axes)
+
+    grad_mean = torch.reshape(grad_output_sum * norm, broadcast_mask)
+    proj_scale = torch.reshape(torch.mul(dot_p * norm, invstd * invstd), broadcast_mask)
+
+    if weight is None:
+        grad_scale = torch.reshape(invstd, broadcast_mask) * 1.0
+    else:
+        grad_scale = torch.reshape(invstd * weight, broadcast_mask)
+
+    if train:
+        proj = (input - mean) * proj_scale
+        grad_input = ((grad_out - proj) - grad_mean) * grad_scale
+    else:
+        grad_input = grad_out * grad_scale
+
+    if output_mask[1]:
+        grad_weight = dot_p * invstd
+    elif weight is not None:
+        grad_weight = torch.zeros_like(weight)  # should be None but doesn't work with vjp
+    else:
+        grad_weight = torch.zeros(())  # should be None but doesn't work with vjp
+
+    if output_mask[2]:
+        grad_bias = grad_output_sum
+    else:
+        grad_bias = torch.zeros_like(grad_output_sum)  # should be None but doesn't work with vjp
+
+    return (grad_input, grad_weight, grad_bias)
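(Aside, not part of the diff: one way to sanity-check the batch-norm decomposition is to compare it against autograd through a hand-written training-mode batch norm. This is only a sketch under stated assumptions: it assumes native_batch_norm_backward and recompute_mean_var are in scope from this module and that register_decomposition returns the function unchanged; the input shapes are made up.)

import torch

torch.manual_seed(0)
x = torch.randn(3, 4, 5, requires_grad=True)
weight = torch.randn(4, requires_grad=True)
eps = 1e-5

# reference: autograd through an explicit training-mode batch norm
dims = [0, 2]
mean = x.mean(dim=dims, keepdim=True)
var = x.var(dim=dims, unbiased=False, keepdim=True)
out = (x - mean) / torch.sqrt(var + eps) * weight.view(1, -1, 1)
grad_out = torch.randn_like(out)
ref_dx, ref_dw = torch.autograd.grad(out, (x, weight), grad_out)

# decomposition: feed it the stats the forward pass would have saved
save_mean = mean.detach().reshape(-1)
save_invstd = (1 / torch.sqrt(var + eps)).detach().reshape(-1)
dx, dw, db = native_batch_norm_backward(
    grad_out, x.detach(), weight.detach(), None, None,
    save_mean, save_invstd, True, eps, [True, True, True],
)
print(torch.allclose(dx, ref_dx, atol=1e-5), torch.allclose(dw, ref_dw, atol=1e-5))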