@@ -203,6 +203,23 @@ std::tuple<Tensor,optional<int64_t>> where_self_batch_rule(
   return std::make_tuple(at::where(condition_, self_, other_), 0);
 }

+std::tuple<Tensor, optional<int64_t>> gelu_backward_batch_rule(
+    const Tensor& grad_out, optional<int64_t> grad_out_bdim, const Tensor& input, optional<int64_t> input_bdim,
+    c10::string_view approximate) {
+
+  // repeat the preprocessing from _binary_pointwise_batch_rule
+  const auto tensor_other = _binary_pointwise_helper(grad_out, grad_out_bdim, input, input_bdim);
+  auto grad_out_ = std::get<0>(tensor_other);
+  auto input_ = std::get<1>(tensor_other);
+
+  // gelu_backward doesn't broadcast well so we need to insist all inputs have a bdim
+  const auto batch_size = get_bdim_size2(grad_out, grad_out_bdim, input, input_bdim);
+  grad_out_ = ensure_has_bdim(grad_out_, grad_out_bdim.has_value(), batch_size);
+  input_ = ensure_has_bdim(input_, input_bdim.has_value(), batch_size);
+
+  return std::make_tuple(at::gelu_backward(grad_out_, input_, approximate), 0);
+}
+
 std::tuple<Tensor,optional<int64_t>> masked_select_batch_rule(
     const Tensor& self, optional<int64_t> self_bdim,
     const Tensor& mask, optional<int64_t> mask_bdim) {
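The added rule does two preprocessing steps: it first aligns the operands the way the generic binary-pointwise rule does (moving any bdim to the front so both tensors share a leading batch layout), and then materializes a bdim on whichever operand lacks one, since gelu_backward does not broadcast well. The snippet below is a minimal, self-contained sketch of what those two steps amount to, written against plain ATen; the helper names are illustrative stand-ins, not the actual `_binary_pointwise_helper`/`ensure_has_bdim` implementations.

```cpp
#include <ATen/ATen.h>
#include <c10/util/Optional.h>

using at::Tensor;
using c10::optional;

// Illustrative stand-in: if the tensor carries a batch dim, move it to dim 0
// so every batched operand shares the same leading layout.
Tensor move_bdim_to_front_sketch(const Tensor& t, optional<int64_t> bdim) {
  return bdim.has_value() ? t.movedim(*bdim, 0) : t;
}

// Illustrative stand-in: if the tensor has no batch dim, materialize one of
// size `batch_size` at dim 0 as a broadcasted (non-copying) view, so that an
// op that does not broadcast well still sees matching leading dimensions.
Tensor ensure_has_bdim_sketch(const Tensor& t, bool has_bdim, int64_t batch_size) {
  if (has_bdim) {
    return t;
  }
  auto sizes = t.sizes().vec();
  sizes.insert(sizes.begin(), batch_size);
  return t.unsqueeze(0).expand(sizes);
}

int main() {
  // grad_out is batched (bdim at 0, batch size 3); input is unbatched.
  auto grad_out = at::randn({3, 4});
  auto input = at::randn({4});

  auto grad_out_ = move_bdim_to_front_sketch(grad_out, 0);
  auto input_ = ensure_has_bdim_sketch(input, /*has_bdim=*/false, /*batch_size=*/3);

  // Both operands are now [3, 4]; the op can run as one batched call.
  auto grad_input = at::gelu_backward(grad_out_, input_, "none");
  return grad_input.dim() == 2 ? 0 : 1;
}
```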
@@ -399,7 +416,7 @@ TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) {
   BINARY_POINTWISE(leaky_relu_backward);
   BINARY_POINTWISE(logit_backward);
   POINTWISE_BOXED(log_sigmoid_backward);
-  BINARY_POINTWISE(gelu_backward);
+  VMAP_SUPPORT(gelu_backward, gelu_backward_batch_rule);
   BINARY_POINTWISE(sigmoid_backward);
   POINTWISE_BOXED(softplus_backward);
   BINARY_POINTWISE(softshrink_backward);
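The registration swap (generic `BINARY_POINTWISE` macro out, hand-written rule registered via `VMAP_SUPPORT`) can be sanity-checked against the defining property of a batch rule: one call over bdim-carrying tensors must match the per-example results stacked along the batch dimension. Below is a small standalone check of that property for `gelu_backward`, using plain ATen outside of functorch; the `"none"` string mirrors the `approximate` argument the rule forwards.

```cpp
#include <ATen/ATen.h>
#include <vector>
#include <iostream>

int main() {
  const int64_t B = 3;
  auto grad_out = at::randn({B, 4});
  auto input = at::randn({B, 4});

  // What the batch rule computes once both operands carry a bdim at dim 0:
  // a single gelu_backward over the stacked tensors.
  auto batched = at::gelu_backward(grad_out, input, "none");

  // Reference semantics: apply the op per example, then stack the results.
  std::vector<at::Tensor> per_example;
  for (int64_t i = 0; i < B; ++i) {
    per_example.push_back(at::gelu_backward(grad_out[i], input[i], "none"));
  }
  auto reference = at::stack(per_example);

  std::cout << std::boolalpha << at::allclose(batched, reference) << std::endl;
  return 0;
}
```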