@@ -73,19 +73,6 @@ std::tuple<Tensor, Tensor> nll_loss_forward_decomposition(
     const c10::optional<Tensor> & weight,
     int64_t reduction, int64_t ignore_index) {
 
-  bool has_ignore_index = ignore_index >= 0;
-  if (has_ignore_index) {
-    // fallback
-    if (target.dim() > 1) {
-      static auto op = c10::Dispatcher::singleton()
-        .findSchemaOrThrow("aten::nll_loss_nd", "");
-      return slow_fallback<Tensor, Tensor>(op, {self, target, weight, reduction, ignore_index});
-    } else {
-      static auto op = c10::Dispatcher::singleton()
-        .findSchemaOrThrow("aten::nll_loss_forward", "");
-      return slow_fallback<Tensor, Tensor>(op, {self, target, weight, reduction, ignore_index});
-    }
-  }
   // self can be [N, C, ...] or [C]
   // target can be [N, ...] or []
@@ -117,17 +104,35 @@ std::tuple<Tensor, Tensor> nll_loss_forward_decomposition(
       {}, result.numel(), self_.scalar_type(),
       self_.layout(), self_.device(), nullopt);
 
+  bool has_ignore_index = ignore_index >= 0;
+  Tensor ignore_index_mask;
+  if (has_ignore_index) {
+    ignore_index_mask = target != ignore_index;
+    result = result * ignore_index_mask;
+    total_weight = ignore_index_mask.sum().to(self_);
+  }
+
   // Apply the reduction
   if (result.dim() > 0) {
     if (reduction == Reduction::Sum) {
       result = result.sum();
     } else if (reduction == Reduction::Mean) {
       if (!weight || !weight->defined()) {
-        result = result.mean();
+        if (has_ignore_index) {
+          TORCH_INTERNAL_ASSERT(ignore_index_mask.defined());
+          // total_weight is ignore_index_mask.sum()
+          result = result.sum() / total_weight;
+        } else {
+          result = result.mean();
+        }
       } else {
         TORCH_INTERNAL_ASSERT(weight_.defined());
         weight_ = weight_.expand(self_.sizes());
         auto wsum = at::gather(weight_, channel_dim, target_).squeeze(channel_dim);
+        if (has_ignore_index) {
+          TORCH_INTERNAL_ASSERT(ignore_index_mask.defined());
+          wsum = wsum * ignore_index_mask;
+        }
         wsum = wsum.sum();
         result = result.sum() / wsum;
         total_weight = wsum;
@@ -136,6 +141,10 @@ std::tuple<Tensor, Tensor> nll_loss_forward_decomposition(
   } else if (reduction == Reduction::Mean && weight && weight->defined()) {
     // here weight is [C] and target is [1]
     auto wsum = at::gather(*weight, channel_dim, target_).squeeze(channel_dim);
+    if (has_ignore_index) {
+      TORCH_INTERNAL_ASSERT(ignore_index_mask.defined());
+      wsum = wsum * ignore_index_mask;
+    }
     total_weight = wsum.sum();
   }
@@ -244,6 +253,7 @@ at::Tensor nll_loss_backward_plumbing(
 
 TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) {
   m.impl("nll_loss_forward", nll_loss_forward_decomposition);
+  m.impl("nll_loss2d_forward", nll_loss_forward_decomposition);
   m.impl("nll_loss_backward", nll_loss_backward_plumbing);
   VMAP_SUPPORT("mse_loss", mse_loss_batch_rule);
   VMAP_SUPPORT("mse_loss_backward", mse_loss_backward_batch_rule);