This repository was archived by the owner on Aug 21, 2025. It is now read-only.

Commit deb706d

Author: Samantha Andow
fix gelu_backward (#640)
Parent: 752b27b

File tree

functorch/csrc/BatchRulesBinaryOps.cpp
test/test_ops.py

2 files changed: +18 -7 lines


functorch/csrc/BatchRulesBinaryOps.cpp

Lines changed: 18 additions & 1 deletion

@@ -203,6 +203,23 @@ std::tuple<Tensor,optional<int64_t>> where_self_batch_rule(
   return std::make_tuple(at::where(condition_, self_, other_), 0);
 }
 
+std::tuple<Tensor, optional<int64_t>> gelu_backward_batch_rule(
+    const Tensor& grad_out, optional<int64_t> grad_out_bdim, const Tensor& input, optional<int64_t> input_bdim,
+    c10::string_view approximate) {
+
+  // repeat the preprocessing from _binary_pointwise_batch_rule
+  const auto tensor_other = _binary_pointwise_helper(grad_out, grad_out_bdim, input, input_bdim);
+  auto grad_out_ = std::get<0>(tensor_other);
+  auto input_ = std::get<1>(tensor_other);
+
+  // gelu_backward doesn't broadcast well so we need to insist all inputs have a bdim
+  const auto batch_size = get_bdim_size2(grad_out, grad_out_bdim, input, input_bdim);
+  grad_out_ = ensure_has_bdim(grad_out_, grad_out_bdim.has_value(), batch_size);
+  input_ = ensure_has_bdim(input_, input_bdim.has_value(), batch_size);
+
+  return std::make_tuple(at::gelu_backward(grad_out_, input_, approximate), 0);
+}
+
 std::tuple<Tensor,optional<int64_t>> masked_select_batch_rule(
     const Tensor& self, optional<int64_t> self_bdim,
     const Tensor& mask, optional<int64_t> mask_bdim) {
@@ -399,7 +416,7 @@ TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) {
   BINARY_POINTWISE(leaky_relu_backward);
   BINARY_POINTWISE(logit_backward);
   POINTWISE_BOXED(log_sigmoid_backward);
-  BINARY_POINTWISE(gelu_backward);
+  VMAP_SUPPORT(gelu_backward, gelu_backward_batch_rule);
   BINARY_POINTWISE(sigmoid_backward);
   POINTWISE_BOXED(softplus_backward);
   BINARY_POINTWISE(softshrink_backward);
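
With the batch rule registered via VMAP_SUPPORT, a vmapped call that reaches gelu_backward no longer requires every argument to already carry a batch dimension. Below is a minimal sketch of what this enables; it is not part of the commit and assumes only functorch's public vmap/grad API and torch.nn.functional.gelu:

# Sketch only -- illustrates the call pattern the batch rule above supports.
import torch
import torch.nn.functional as F
from functorch import grad, vmap

x = torch.randn(8, 5)  # a batch of 8 independent inputs

# grad(...) of this scalar-valued function calls gelu_backward under the hood;
# vmapping it dispatches to gelu_backward_batch_rule, which makes sure both
# grad_out and input carry a batch dimension before calling at::gelu_backward.
per_sample_grads = vmap(grad(lambda t: F.gelu(t).sum()))(x)
print(per_sample_grads.shape)  # torch.Size([8, 5])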

test/test_ops.py

Lines changed: 0 additions & 6 deletions

@@ -625,8 +625,6 @@ def vjp_of_vjp(*args_and_cotangents):
         xfail('index_put', ''),
         xfail('lu_solve'),
         xfail('index_copy'),
-        xfail('nn.functional.gelu', device_type='cpu'),
-
         xfail('linalg.lu_factor', ''),
     })
 
@@ -703,9 +701,6 @@ def test_vmapvjp(self, device, dtype, op):
         xfail('nanquantile'),
         xfail('quantile'),
 
-        # RuntimeError: vmap: inplace arithmetic(self, *extra_args)
-        xfail('nn.functional.gelu', device_type='cpu'),
-
         # Not implemented
         xfail('scatter'),
 
@@ -778,7 +773,6 @@ def test_vmapjvp(self, device, dtype, op):
         xfail('maximum'),
         xfail('linalg.householder_product'),
         xfail('tensor_split'),
-        xfail('nn.functional.gelu', device_type='cpu'),
         xfail('quantile'),
         xfail('var_mean'),
         xfail('as_strided'),
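
The removed xfail entries had marked nn.functional.gelu as an expected failure on CPU in the vmap-plus-autograd test lists; with the new batch rule those tests are expected to pass. A rough manual check in the same spirit is sketched below; it is not taken from the test suite and assumes functorch's vjp and vmap, comparing the vmapped vjp of gelu against a per-example loop:

# Sketch only -- a hand-rolled version of the kind of check these tests perform.
import torch
import torch.nn.functional as F
from functorch import vjp, vmap

x = torch.randn(4, 3)
cotangents = torch.randn(4, 3)

def gelu_vjp(xi, cti):
    # vjp returns (output, vjp_fn); vjp_fn maps a cotangent to input gradients
    _, vjp_fn = vjp(F.gelu, xi)
    return vjp_fn(cti)[0]

batched = vmap(gelu_vjp)(x, cotangents)  # exercises gelu_backward under vmap
looped = torch.stack([gelu_vjp(xi, cti) for xi, cti in zip(x, cotangents)])
assert torch.allclose(batched, looped)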
