
Commit 26d4cfc

Author: Samantha Andow
Fix MSE forward, use decomposition for MSE backward (#860)
* use decomposition for mse backward
* only reshape if there was no reduction
* add tests, fix shape of mse loss forward
* remove mse xfail
* simplify backwards rule
1 parent 6819a15 commit 26d4cfc
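
The backward rule is now expressed as a plain decomposition of the MSE gradient rather than a dedicated batch rule. As a rough Python sketch of what that decomposition computes (the standalone function name here is illustrative, not the actual functorch API; see the C++ diff below for the real change):

import torch

def mse_loss_backward_decomposition(grad_output, self, target, reduction="none"):
    # d/dself of (self - target)**2 is 2 * (self - target); the chain rule
    # multiplies in the incoming grad_output.
    grad_input = 2.0 * (self - target) * grad_output
    if reduction == "mean":
        # A mean reduction scales every element's gradient by 1 / numel.
        grad_input = grad_input / self.numel()
    return grad_input

# Quick check against autograd for the 'mean' reduction:
x = torch.randn(3, 4, requires_grad=True)
t = torch.randn(3, 4)
torch.nn.functional.mse_loss(x, t).backward()
manual = mse_loss_backward_decomposition(torch.ones(()), x.detach(), t, reduction="mean")
assert torch.allclose(x.grad, manual)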

File tree

3 files changed: +44 −22 lines


functorch/csrc/BatchRulesLoss.cpp

Lines changed: 13 additions & 20 deletions
@@ -34,7 +34,10 @@ mse_loss_batch_rule(const at::Tensor& self, optional<int64_t> self_bdim, const a
   if (result.dim() == 1) {
     return std::make_tuple(result, 0);
   } else if (reduction == Reduction::None) {
-    return std::make_tuple(result, 0);
+    DimVector end_shape;
+    const auto batched_elem = self_bdim.has_value() ?
+      moveBatchDimToFront(self, self_bdim) : moveBatchDimToFront(target, target_bdim);
+    return std::make_tuple(result.reshape(batched_elem.sizes()), 0);
   } else if (reduction == Reduction::Sum) {
     return std::make_tuple(result.sum(-1), 0);
   } else if (reduction == Reduction::Mean) {
@@ -43,28 +46,18 @@ mse_loss_batch_rule(const at::Tensor& self, optional<int64_t> self_bdim, const a
   TORCH_INTERNAL_ASSERT(false);
 };

-std::tuple<at::Tensor,optional<int64_t>>
+at::Tensor
 mse_loss_backward_batch_rule(
-    const at::Tensor& grad_output, optional<int64_t> grad_output_bdim,
-    const at::Tensor& self, optional<int64_t> self_bdim,
-    const at::Tensor& target, optional<int64_t> target_bdim,
+    const at::Tensor& grad_output,
+    const at::Tensor& self,
+    const at::Tensor& target,
     int64_t reduction) {
-  auto grad_output_ = moveBatchDimToFront(grad_output, grad_output_bdim);
-  auto self_ = moveBatchDimToFront(self, self_bdim);
-  auto target_ = moveBatchDimToFront(target, target_bdim);
-  if (reduction != Reduction::None && grad_output_bdim.has_value()) {
-    // grad_output_ is of shape [N]. Input is of shape [N?, ...].
-    // We need to view grad_output_ as shape [N, ...].
-    auto self_rank_without_bdim = rankWithoutBatchDim(self, self_bdim);
-    DimVector view_shape(self_rank_without_bdim + 1, 1);
-    view_shape[0] = grad_output_.size(0);
-    grad_output_ = grad_output_.view(view_shape);
-  }
-  auto result = at::mse_loss_backward(grad_output_, self_, target_, Reduction::None);
+
+  const auto result = 2. * (self - target) * grad_output;
   if (reduction == Reduction::Mean) {
-    return std::make_tuple(result / numelWithoutBatchDim(self, self_bdim), 0);
+    return result / self.numel();
   }
-  return std::make_tuple(result, 0);
+  return result;
 };

 static Tensor apply_loss_reduction(const at::Tensor& unreduced, int64_t reduction) {
@@ -303,7 +296,7 @@ TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) {
   m.impl("nll_loss_backward", nll_loss_backward_decomposition);
   m.impl("nll_loss2d_backward", nll_loss_backward_decomposition);
   VMAP_SUPPORT(mse_loss, mse_loss_batch_rule);
-  VMAP_SUPPORT(mse_loss_backward, mse_loss_backward_batch_rule);
+  m.impl("mse_loss_backward", mse_loss_backward_batch_rule);
   m.impl("binary_cross_entropy", binary_cross_entropy_plumbing);
   m.impl("binary_cross_entropy_backward", binary_cross_entropy_backward_plumbing);
 }
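
To illustrate the forward fix: with reduction='none', the vmapped result is now reshaped to the batched operand's shape instead of staying flattened. A small sketch of the expected behavior, assuming the functorch-era vmap import (shapes here are arbitrary):

import torch
import torch.nn.functional as F
from functorch import vmap

x = torch.randn(4, 3, 2)   # batch of 4 examples, each of shape (3, 2)
t = torch.randn(4, 3, 2)

# Default 'mean' reduction: one scalar loss per example.
per_example_mean = vmap(F.mse_loss)(x, t)
print(per_example_mean.shape)   # torch.Size([4])

# 'none' reduction: elementwise losses keep the per-example shape (3, 2).
per_example_none = vmap(lambda a, b: F.mse_loss(a, b, reduction="none"))(x, t)
print(per_example_none.shape)   # torch.Size([4, 3, 2])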

test/functorch_additional_op_db.py

Lines changed: 31 additions & 0 deletions
@@ -185,6 +185,37 @@ def generator():
 ))


+def sample_inputs_mse_loss(op_info, device, dtype, requires_grad, **kwargs):
+    def make_input(shape, requires_grad=requires_grad):
+        return make_tensor(shape, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    rhs_requires_grad = kwargs.get('rhs_requires_grad', requires_grad)
+    S = 5
+
+    shapes = ((S, S), (S, S, S), (S, S, S, S))
+    reductions = ("none", "mean", "sum")
+
+    for shape, reduction in itertools.product(shapes, reductions):
+        yield SampleInput(make_input(shape),
+                          args=(make_input(shape, requires_grad=rhs_requires_grad),),
+                          kwargs={"reduction": reduction})
+
+
+additional_op_db.append(
+    OpInfo(
+        "nn.functional.mse_loss",
+        variant_test_name="functorch",
+        sample_inputs_func=sample_inputs_mse_loss,
+        supports_out=False,
+        supports_forward_ad=True,
+        supports_fwgrad_bwgrad=True,
+        dtypes=floating_types_and(torch.float16),
+        backward_dtypes=floating_types(),
+        dtypesIfCUDA=floating_types_and(torch.bfloat16, torch.float16),
+        backward_dtypesIfCUDA=floating_types_and(torch.bfloat16, torch.float16),
+    ))
+
+
 def sample_inputs_getitem(op_info, device, dtype, requires_grad, **kwargs):
     S = 5
     test_args = [
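
The new OpInfo routes nn.functional.mse_loss through the standard functorch test suites. A hypothetical standalone check in the same spirit (not part of this commit), exercising the decomposition-based backward through vmap(grad(...)):

import torch
import torch.nn.functional as F
from functorch import grad, vmap

B, S = 3, 5
x = torch.randn(B, S, S)
t = torch.randn(B, S, S)

# Per-example gradient of the 'mean'-reduced loss w.r.t. the input.
per_example_grad = vmap(grad(F.mse_loss))(x, t)

# Closed form of the same gradient: 2 * (x - t) / numel, per example.
manual = torch.stack([2.0 * (xi - ti) / xi.numel() for xi, ti in zip(x, t)])
assert torch.allclose(per_example_grad, manual)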

test/test_ops.py

Lines changed: 0 additions & 2 deletions
@@ -1322,8 +1322,6 @@ def test_extremal_numerics_l1_loss(self, device):
         cotangents = torch.randn_like(result, device=device)
         self._compare_jacobians_of_vjp(torch.nn.functional.l1_loss, (cotangents, input, target))

-    # ("https://github.com/pytorch/functorch/issues/858")
-    @unittest.expectedFailure
     def test_extremal_numerics_mse_loss(self, device):
         N, C, H, W = 3, 4, 5, 6
         shapes = ((N, C), (N, C, H), (N, C, H, W))
