@@ -46,20 +46,6 @@ mse_loss_batch_rule(const at::Tensor& self, optional<int64_t> self_bdim, const a
   TORCH_INTERNAL_ASSERT(false);
 };
 
-at::Tensor
-mse_loss_backward_batch_rule(
-    const at::Tensor& grad_output,
-    const at::Tensor& self,
-    const at::Tensor& target,
-    int64_t reduction) {
-
-  const auto result = 2. * (self - target) * grad_output;
-  if (reduction == Reduction::Mean) {
-    return result / self.numel();
-  }
-  return result;
-};
-
 static Tensor apply_loss_reduction(const at::Tensor& unreduced, int64_t reduction) {
   if (reduction == at::Reduction::Mean) {
     return unreduced.mean();
@@ -296,7 +282,7 @@ TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) {
   m.impl("nll_loss_backward", nll_loss_backward_decomposition);
   m.impl("nll_loss2d_backward", nll_loss_backward_decomposition);
   VMAP_SUPPORT(mse_loss, mse_loss_batch_rule);
-  m.impl("mse_loss_backward", mse_loss_backward_batch_rule);
+  // mse_loss_backward uses a decomposition for its batch rule
   m.impl("binary_cross_entropy", binary_cross_entropy_plumbing);
   m.impl("binary_cross_entropy_backward", binary_cross_entropy_backward_plumbing);
 }
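
For context: the hand-written batch rule deleted above computed grad_input = 2 * (self - target) * grad_output, scaled by 1/numel() for mean reduction. The replacement comment says mse_loss_backward now uses a decomposition, i.e. the op is re-expressed in terms of primitives (sub, mul, div) whose batch rules vmap already has, so no dedicated rule is needed. Below is a minimal sketch of what such a decomposition could look like, reusing the exact formula from the deleted code; the name mse_loss_backward_decomposition and the commented-out registration are illustrative assumptions, not necessarily what functorch ships.

#include <ATen/ATen.h>

// Hypothetical decomposition (for illustration only): mse_loss_backward
// rewritten as a composite of ops that already have batch rules. Under vmap,
// each primitive re-enters the dispatcher and is batched by its own rule.
static at::Tensor mse_loss_backward_decomposition(
    const at::Tensor& grad_output,
    const at::Tensor& self,
    const at::Tensor& target,
    int64_t reduction) {
  // d/dself of (self - target)^2 is 2 * (self - target); the chain rule
  // multiplies in grad_output.
  auto grad_input = 2. * (self - target) * grad_output;
  if (reduction == at::Reduction::Mean) {
    // Mean reduction averaged over numel() elements in the forward pass,
    // so the backward pass scales by 1/numel().
    return grad_input / self.numel();
  }
  return grad_input;
}

// Illustrative registration, mirroring the nll_loss_backward_decomposition
// entries in the hunk above:
//   m.impl("mse_loss_backward", mse_loss_backward_decomposition);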