Skip to content

Commit 96dead4

Browse files
author
Samantha Andow
authored
make mse loss backward use decomp (#866)
1 parent 23e8c66 commit 96dead4

File tree

2 files changed

+2
-15
lines changed

2 files changed

+2
-15
lines changed

functorch/_src/eager_transforms.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1332,4 +1332,5 @@ def _register_python_decomposition_vmap(decomp):
13321332
_register_jit_decomposition(torch.ops.aten._softmax_backward_data.default)
13331333
_register_jit_decomposition(torch.ops.aten.log_sigmoid_forward.default)
13341334
_register_jit_decomposition(torch.ops.aten.native_layer_norm_backward.default)
1335+
_register_python_decomposition_vmap(torch.ops.aten.mse_loss_backward.default)
13351336
_register_python_decomposition_vmap(torch.ops.aten.addr.default)

functorch/csrc/BatchRulesLoss.cpp

Lines changed: 1 addition & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -46,20 +46,6 @@ mse_loss_batch_rule(const at::Tensor& self, optional<int64_t> self_bdim, const a
4646
TORCH_INTERNAL_ASSERT(false);
4747
};
4848

49-
at::Tensor
50-
mse_loss_backward_batch_rule(
51-
const at::Tensor& grad_output,
52-
const at::Tensor& self,
53-
const at::Tensor& target,
54-
int64_t reduction) {
55-
56-
const auto result = 2. * (self - target) * grad_output;
57-
if (reduction == Reduction::Mean) {
58-
return result / self.numel();
59-
}
60-
return result;
61-
};
62-
6349
static Tensor apply_loss_reduction(const at::Tensor& unreduced, int64_t reduction) {
6450
if (reduction == at::Reduction::Mean) {
6551
return unreduced.mean();
@@ -296,7 +282,7 @@ TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) {
296282
m.impl("nll_loss_backward", nll_loss_backward_decomposition);
297283
m.impl("nll_loss2d_backward", nll_loss_backward_decomposition);
298284
VMAP_SUPPORT(mse_loss, mse_loss_batch_rule);
299-
m.impl("mse_loss_backward", mse_loss_backward_batch_rule);
285+
// mse_loss_backwards uses a decomposition for its batch rule
300286
m.impl("binary_cross_entropy", binary_cross_entropy_plumbing);
301287
m.impl("binary_cross_entropy_backward", binary_cross_entropy_backward_plumbing);
302288
}

0 commit comments

Comments
 (0)