@@ -34,7 +34,10 @@ mse_loss_batch_rule(const at::Tensor& self, optional<int64_t> self_bdim, const a
   if (result.dim() == 1) {
     return std::make_tuple(result, 0);
   } else if (reduction == Reduction::None) {
-    return std::make_tuple(result, 0);
+    DimVector end_shape;
+    const auto batched_elem = self_bdim.has_value() ?
+      moveBatchDimToFront(self, self_bdim) : moveBatchDimToFront(target, target_bdim);
+    return std::make_tuple(result.reshape(batched_elem.sizes()), 0);
   } else if (reduction == Reduction::Sum) {
     return std::make_tuple(result.sum(-1), 0);
   } else if (reduction == Reduction::Mean) {
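The old Reduction::None branch returned the unreduced result directly, but the rule computes the loss on flattened per-example views (which is why the Sum branch can use result.sum(-1)), so callers saw a flat [N, numel] tensor instead of the logical shape. A minimal sketch of the fix's effect, assuming a standalone ATen build; the shapes and main() harness here are illustrative, not part of the commit:

#include <ATen/ATen.h>
#include <iostream>

int main() {
  // With reduction=None on flattened per-example views, a [4, 2, 3] batch
  // yields an unreduced loss of shape [4, 6].
  const auto self = at::randn({4, 2, 3});
  const auto target = at::randn({4, 2, 3});
  const auto flat = at::mse_loss(self.reshape({4, -1}), target.reshape({4, -1}),
                                 at::Reduction::None);  // shape [4, 6]
  // Reshaping back to the batched input's sizes restores [4, 2, 3],
  // mirroring result.reshape(batched_elem.sizes()) in the hunk above.
  const auto fixed = flat.reshape(self.sizes());
  std::cout << fixed.sizes() << "\n";  // prints [4, 2, 3]
  return 0;
}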
@@ -43,28 +46,18 @@ mse_loss_batch_rule(const at::Tensor& self, optional<int64_t> self_bdim, const a
   TORCH_INTERNAL_ASSERT(false);
 };

-std::tuple<at::Tensor,optional<int64_t>>
+at::Tensor
 mse_loss_backward_batch_rule(
-    const at::Tensor& grad_output, optional<int64_t> grad_output_bdim,
-    const at::Tensor& self, optional<int64_t> self_bdim,
-    const at::Tensor& target, optional<int64_t> target_bdim,
+    const at::Tensor& grad_output,
+    const at::Tensor& self,
+    const at::Tensor& target,
     int64_t reduction) {
-  auto grad_output_ = moveBatchDimToFront(grad_output, grad_output_bdim);
-  auto self_ = moveBatchDimToFront(self, self_bdim);
-  auto target_ = moveBatchDimToFront(target, target_bdim);
-  if (reduction != Reduction::None && grad_output_bdim.has_value()) {
-    // grad_output_ is of shape [N]. Input is of shape [N?, ...].
-    // We need to view grad_output_ as shape [N, ...].
-    auto self_rank_without_bdim = rankWithoutBatchDim(self, self_bdim);
-    DimVector view_shape(self_rank_without_bdim + 1, 1);
-    view_shape[0] = grad_output_.size(0);
-    grad_output_ = grad_output_.view(view_shape);
-  }
-  auto result = at::mse_loss_backward(grad_output_, self_, target_, Reduction::None);
+
+  const auto result = 2. * (self - target) * grad_output;
   if (reduction == Reduction::Mean) {
-    return std::make_tuple(result / numelWithoutBatchDim(self, self_bdim), 0);
+    return result / self.numel();
   }
-  return std::make_tuple(result, 0);
+  return result;
 };

 static Tensor apply_loss_reduction(const at::Tensor& unreduced, int64_t reduction) {
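The rewritten backward is a plain decomposition: since mse_loss is a mean or sum of (self - target)^2, its gradient with respect to self is 2 * (self - target) * grad_output, divided by numel() for Mean reduction. A quick sanity check against ATen's own kernel, as a sketch assuming a standalone ATen build (the main() harness is illustrative, not part of the commit):

#include <ATen/ATen.h>
#include <iostream>

int main() {
  const auto self = at::randn({3, 4});
  const auto target = at::randn({3, 4});

  // Mean reduction: grad_output is a scalar, and
  // d/d(self) mean((self - target)^2) = 2 * (self - target) / numel.
  const auto grad_output = at::randn({});
  const auto manual = 2. * (self - target) * grad_output / self.numel();
  const auto reference =
      at::mse_loss_backward(grad_output, self, target, at::Reduction::Mean);
  std::cout << at::allclose(manual, reference) << "\n";  // expect 1
  return 0;
}

With the bdim plumbing gone, the function is an ordinary composite of vmap-compatible ops, which is presumably why the registration in the final hunk switches from VMAP_SUPPORT to a plain m.impl decomposition.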
@@ -303,7 +296,7 @@ TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) {
   m.impl("nll_loss_backward", nll_loss_backward_decomposition);
   m.impl("nll_loss2d_backward", nll_loss_backward_decomposition);
   VMAP_SUPPORT(mse_loss, mse_loss_batch_rule);
-  VMAP_SUPPORT(mse_loss_backward, mse_loss_backward_batch_rule);
+  m.impl("mse_loss_backward", mse_loss_backward_batch_rule);
   m.impl("binary_cross_entropy", binary_cross_entropy_plumbing);
   m.impl("binary_cross_entropy_backward", binary_cross_entropy_backward_plumbing);
 }