@@ -67,108 +67,79 @@ mse_loss_backward_batch_rule(
   return std::make_tuple(result, 0);
 };

-std::tuple<at::Tensor,optional<int64_t>,at::Tensor,optional<int64_t>>
-nll_loss_forward_self_target_batch_rule(
-    const at::Tensor & self, optional<int64_t> self_bdim,
-    const at::Tensor & target, optional<int64_t> target_bdim,
-    int64_t reduction) {
-  TORCH_INTERNAL_ASSERT(self.dim() == 3 && target.dim() == 2);
-
-  if (reduction == Reduction::None) {
-    int64_t batch_size = self.size(*self_bdim);
-    auto self_ = reshape_dim_into(*self_bdim, 0, self);
-    auto target_ = reshape_dim_into(*target_bdim, 0, target);
-    auto result = at::nll_loss_forward(self_, target_, nullopt, reduction, -100);
-    return std::make_tuple(
-      reshape_dim_outof(0, batch_size, std::get<0>(result)), 0,
-      std::get<1>(result), nullopt
-    );
-  } else if (reduction == Reduction::Sum) {
-    int64_t batch_size = self.size(*self_bdim);
-    auto self_ = reshape_dim_into(*self_bdim, 0, self);
-    auto target_ = reshape_dim_into(*target_bdim, 0, target);
-    auto res = at::nll_loss_forward(self_, target_, nullopt, Reduction::None, -100);
-    auto output = std::get<0>(res);
-    output = reshape_dim_outof(0, batch_size, output);
-    auto total_weight = self_.new_full({}, output.size(-1));
-    return std::make_tuple(
-      output.sum(-1), 0,
-      // NB: total_weight = 0 after Reduction::None
-      total_weight, nullopt
-    );
-  } else if (reduction == Reduction::Mean) {
-    int64_t batch_size = self.size(*self_bdim);
-    auto self_ = reshape_dim_into(*self_bdim, 0, self);
-    auto target_ = reshape_dim_into(*target_bdim, 0, target);
-    auto res = at::nll_loss_forward(self_, target_, nullopt, Reduction::None, -100);
-    auto output = std::get<0>(res);
-    output = reshape_dim_outof(0, batch_size, output);
-    auto total_weight = self_.new_full({}, output.size(-1));
-    return std::make_tuple(
-      output.mean(-1), 0,
-      // NB: total_weight = 0 after Reduction::None
-      total_weight, nullopt
-    );
-  }
-  TORCH_INTERNAL_ASSERT(false);
-}
-
-std::tuple<at::Tensor,at::Tensor> nll_loss_forward_plumbing(
-    const at::Tensor & self,
-    const at::Tensor & target,
-    const c10::optional<at::Tensor> & weight,
+std::tuple<Tensor, Tensor> nll_loss_forward_decomposition(
+    const Tensor & self,
+    const Tensor & target,
+    const c10::optional<Tensor> & weight,
     int64_t reduction, int64_t ignore_index) {
-  auto maybe_layer = maybeCurrentDynamicLayer();
-  TORCH_INTERNAL_ASSERT(maybe_layer.has_value());
-  int64_t cur_level = maybe_layer->layerId();
-
-  Tensor self_value;
-  optional<int64_t> self_bdim;
-  std::tie(self_value, self_bdim) = unwrapTensorAtLevel(self, cur_level);

-  Tensor target_value;
-  optional<int64_t> target_bdim;
-  std::tie(target_value, target_bdim) = unwrapTensorAtLevel(target, cur_level);
-
-  optional<Tensor> weight_value;
-  optional<int64_t> weight_bdim;
-  if (weight) {
-    std::tie(weight_value, weight_bdim) = unwrapTensorAtLevel(*weight, cur_level);
+  bool has_ignore_index = ignore_index >= 0;
+  if (has_ignore_index) {
+    // fallback
+    if (target.dim() > 1) {
+      static auto op = c10::Dispatcher::singleton()
+        .findSchemaOrThrow("aten::nll_loss_nd", "");
+      return slow_fallback<Tensor, Tensor>(op, {self, target, weight, reduction, ignore_index});
+    } else {
+      static auto op = c10::Dispatcher::singleton()
+        .findSchemaOrThrow("aten::nll_loss_forward", "");
+      return slow_fallback<Tensor, Tensor>(op, {self, target, weight, reduction, ignore_index});
+    }
   }
+  // self can be [N, C, ...] or [C]
+  // target can be [N, ...] or []

-  if (self_bdim && target_bdim && (!weight || !weight->defined()) && ignore_index < 0) {
-    c10::impl::ExcludeDispatchKeyGuard guard(kBatchedKey);
-    auto results = nll_loss_forward_self_target_batch_rule(
-        self_value, self_bdim, target_value, target_bdim, reduction);
-    return std::make_tuple(
-      makeBatched(std::get<0>(results), std::get<1>(results), cur_level),
-      makeBatched(std::get<2>(results), std::get<3>(results), cur_level)
-    );
+  int64_t channel_dim = 1;
+  if (self.dim() < 2) {
+    channel_dim = 0;
   }
-
-  if ((!weight || !weight->defined()) && ignore_index < 0) {
-    // Decomposition: gather to get unreduced loss. 1 is for the C dim, that's always 1.
-    // gather can handle arbitrary strides so it's a good candidate for a decomposition.
-    auto target_ = target.unsqueeze(1);
-    auto result = at::gather(self, 1, target_).squeeze(1);
-    auto total_weight = at::full(
-      {}, result.numel(), self.scalar_type(),
-      self.layout(), self.device(), nullopt);
-
-    // Apply the reduction
-    switch (reduction) {
-      case Reduction::None:
-        return std::make_tuple(-result, total_weight);
-      case Reduction::Sum:
-        return std::make_tuple(-result.sum(), total_weight);
-      case Reduction::Mean:
-        return std::make_tuple(-result.mean(), total_weight);
+  auto self_ = self;
+  Tensor weight_;
+
+  if (weight && weight->defined()) {
+    // Here is a specific case with reduction mean and non-batched tensors
+    // https://github.com/pytorch/pytorch/issues/61309
+    // In this case weight is cancelled: w * x[t] / w -> x[t]
+    if (!(reduction == Reduction::Mean && self_.dim() < 2)) {
+      // reshape weights to [1, C, 1, ..., 1]
+      auto shape = weight->sizes();
+      VmapDimVector new_shape(self_.dim(), 1);
+      new_shape[channel_dim] = shape[0];
+      weight_ = weight->reshape(new_shape);
+      self_ = self_ * weight_;
     }
   }
+  auto target_ = target.unsqueeze(channel_dim);
+  // target can be [N, 1, ...] or [1]
+
+  auto result = -at::gather(self_, channel_dim, target_).squeeze(channel_dim);
+  auto total_weight = at::full(
+    {}, result.numel(), self_.scalar_type(),
+    self_.layout(), self_.device(), nullopt);
+
+  // Apply the reduction
+  if (result.dim() > 0) {
+    if (reduction == Reduction::Sum) {
+      result = result.sum();
+    } else if (reduction == Reduction::Mean) {
+      if (!weight || !weight->defined()) {
+        result = result.mean();
+      } else {
+        TORCH_INTERNAL_ASSERT(weight_.defined());
+        weight_ = weight_.expand(self_.sizes());
+        auto wsum = at::gather(weight_, channel_dim, target_).squeeze(channel_dim);
+        wsum = wsum.sum();
+        result = result.sum() / wsum;
+        total_weight = wsum;
+      }
+    }
+  } else if (reduction == Reduction::Mean && weight && weight->defined()) {
+    // here weight is [C] and target is [1]
+    auto wsum = at::gather(*weight, channel_dim, target_).squeeze(channel_dim);
+    total_weight = wsum.sum();
+  }

-  static auto op = c10::Dispatcher::singleton()
-    .findSchemaOrThrow("aten::nll_loss_forward", "");
-  return slow_fallback<Tensor,Tensor>(op, {self, target, weight, reduction, ignore_index});
+  return std::make_tuple(result, total_weight);
 }

 std::tuple<at::Tensor,optional<int64_t>>
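The core of the new nll_loss_forward_decomposition is the gather along the channel dimension: reading out self[n][target[n]] and negating it reproduces the unreduced NLL loss. The standalone sketch below is not part of the diff; it assumes a libtorch build and uses the public at:: ops rather than the file's helpers, and checks that reading of the unweighted, no-ignore-index path against at::nll_loss_forward.

#include <ATen/ATen.h>
#include <iostream>

int main() {
  // [N, C] log-probabilities and [N] class indices, mirroring the
  // "self can be [N, C, ...]" comment in the decomposition.
  auto self = at::log_softmax(at::randn({4, 3}), /*dim=*/1);
  auto target = at::randint(0, 3, {4}, at::kLong);

  // Decomposition: gather along the channel dim, squeeze it back out, negate.
  int64_t channel_dim = 1;
  auto picked = -at::gather(self, channel_dim, target.unsqueeze(channel_dim))
                     .squeeze(channel_dim);

  // Compare against the reference op for two reduction modes.
  auto ref_none = std::get<0>(at::nll_loss_forward(self, target, {}, at::Reduction::None, -100));
  auto ref_mean = std::get<0>(at::nll_loss_forward(self, target, {}, at::Reduction::Mean, -100));

  std::cout << std::boolalpha
            << at::allclose(picked, ref_none) << " "
            << at::allclose(picked.mean(), ref_mean) << std::endl;
  return 0;
}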
@@ -272,7 +243,7 @@ at::Tensor nll_loss_backward_plumbing(


 TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) {
-  m.impl("nll_loss_forward", nll_loss_forward_plumbing);
+  m.impl("nll_loss_forward", nll_loss_forward_decomposition);
   m.impl("nll_loss_backward", nll_loss_backward_plumbing);
   VMAP_SUPPORT("mse_loss", mse_loss_batch_rule);
   VMAP_SUPPORT("mse_loss_backward", mse_loss_backward_batch_rule);