
Commit 38d1161

Updated nll loss decomposition rule with ignore_index (#218)
* WIP on adding ignore index decomp rule
* Updated nll_loss decomposition rule to take into account ignore_index
* Recoded total_weight computation for has_ignore_index without using .item
* Added required decompositions
* Fixed nits
1 parent b107bab commit 38d1161

File tree

3 files changed: +27 / -15 lines

functorch/csrc/BatchRulesDecompositions.cpp

Lines changed: 3 additions & 0 deletions
@@ -121,6 +121,9 @@ TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) {
   OP_DECOMPOSE(var_mean);
   OP_DECOMPOSE2(var_mean, dim);
   OP_DECOMPOSE2(where, self);
+  OP_DECOMPOSE(nll_loss_nd);
+  OP_DECOMPOSE(nll_loss);
+  OP_DECOMPOSE(nll_loss2d);
 }

 }}
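
As a quick orientation (a hedged Python sketch, not part of the commit): registering these ops with OP_DECOMPOSE means vmap re-dispatches the composite nll_loss variants to nll_loss_forward, which has its own batch rule, so a call like the following is expected to batch without the slow fallback. Shapes and names below are illustrative only.

import torch
import torch.nn.functional as F
from functorch import vmap

# B independent [N, C] classification problems, vmapped over the leading dim.
B, N, C = 4, 8, 5
log_probs = torch.randn(B, N, C).log_softmax(-1)
targets = torch.randint(0, C, (B, N))

# nll_loss decomposes to nll_loss_forward under vmap; one mean loss per sample.
per_sample_loss = vmap(F.nll_loss)(log_probs, targets)  # shape [B]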

functorch/csrc/BatchRulesLoss.cpp

Lines changed: 24 additions & 14 deletions
@@ -73,19 +73,6 @@ std::tuple<Tensor, Tensor> nll_loss_forward_decomposition(
     const c10::optional<Tensor> & weight,
     int64_t reduction, int64_t ignore_index) {

-  bool has_ignore_index = ignore_index >= 0;
-  if (has_ignore_index) {
-    // fallback
-    if (target.dim() > 1) {
-      static auto op = c10::Dispatcher::singleton()
-        .findSchemaOrThrow("aten::nll_loss_nd", "");
-      return slow_fallback<Tensor, Tensor>(op, {self, target, weight, reduction, ignore_index});
-    } else {
-      static auto op = c10::Dispatcher::singleton()
-        .findSchemaOrThrow("aten::nll_loss_forward", "");
-      return slow_fallback<Tensor, Tensor>(op, {self, target, weight, reduction, ignore_index});
-    }
-  }
   // self can be [N, C, ...] or [C]
   // target can be [N, ...] or []

@@ -117,17 +104,35 @@ std::tuple<Tensor, Tensor> nll_loss_forward_decomposition(
       {}, result.numel(), self_.scalar_type(),
       self_.layout(), self_.device(), nullopt);

+  bool has_ignore_index = ignore_index >= 0;
+  Tensor ignore_index_mask;
+  if (has_ignore_index) {
+    ignore_index_mask = target != ignore_index;
+    result = result * ignore_index_mask;
+    total_weight = ignore_index_mask.sum().to(self_);
+  }
+
   // Apply the reduction
   if (result.dim() > 0) {
     if (reduction == Reduction::Sum) {
       result = result.sum();
     } else if (reduction == Reduction::Mean) {
       if (!weight || !weight->defined()) {
-        result = result.mean();
+        if (has_ignore_index) {
+          TORCH_INTERNAL_ASSERT(ignore_index_mask.defined());
+          // total_weight is ignore_index_mask.sum()
+          result = result.sum() / total_weight;
+        } else {
+          result = result.mean();
+        }
       } else {
         TORCH_INTERNAL_ASSERT(weight_.defined());
         weight_ = weight_.expand(self_.sizes());
         auto wsum = at::gather(weight_, channel_dim, target_).squeeze(channel_dim);
+        if (has_ignore_index) {
+          TORCH_INTERNAL_ASSERT(ignore_index_mask.defined());
+          wsum = wsum * ignore_index_mask;
+        }
         wsum = wsum.sum();
         result = result.sum() / wsum;
         total_weight = wsum;
@@ -136,6 +141,10 @@ std::tuple<Tensor, Tensor> nll_loss_forward_decomposition(
   } else if (reduction == Reduction::Mean && weight && weight->defined()) {
     // here weight is [C] and target is [1]
     auto wsum = at::gather(*weight, channel_dim, target_).squeeze(channel_dim);
+    if (has_ignore_index) {
+      TORCH_INTERNAL_ASSERT(ignore_index_mask.defined());
+      wsum = wsum * ignore_index_mask;
+    }
     total_weight = wsum.sum();
   }

@@ -244,6 +253,7 @@ at::Tensor nll_loss_backward_plumbing(

 TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) {
   m.impl("nll_loss_forward", nll_loss_forward_decomposition);
+  m.impl("nll_loss2d_forward", nll_loss_forward_decomposition);
   m.impl("nll_loss_backward", nll_loss_backward_plumbing);
   VMAP_SUPPORT("mse_loss", mse_loss_batch_rule);
   VMAP_SUPPORT("mse_loss_backward", mse_loss_backward_batch_rule);

test/test_vmap.py

Lines changed: 0 additions & 1 deletion
@@ -3128,7 +3128,6 @@ def test_vmap_exhaustive(self, device, dtype, op):
     xfail('linalg.multi_dot'),
     xfail('nanmean'),
     xfail('nn.functional.layer_norm'),
-    xfail('nn.functional.nll_loss'),
     xfail('vstack'),
     xfail('block_diag'),
     xfail('nn.functional.dropout'),
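
With the xfail removed, test_vmap_exhaustive now exercises nn.functional.nll_loss. A hedged usage sketch of the newly supported case (a non-negative ignore_index under vmap) might look like the following; the shapes and the choice of 0 as the ignored label are illustrative only.

import torch
import torch.nn.functional as F
from functorch import vmap

B, N, C = 3, 6, 4
log_probs = torch.randn(B, N, C).log_softmax(-1)
targets = torch.randint(0, C, (B, N))
targets[:, 0] = 0  # pretend class 0 is a padding label to be ignored

# ignore_index >= 0 hits the masking path in the decomposition above.
loss_fn = lambda lp, t: F.nll_loss(lp, t, ignore_index=0)
per_sample = vmap(loss_fn)(log_probs, targets)  # one mean loss per batch element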
