
Commit d6ef390

Implemented nll loss through decomposition (#208)
* WIP on nll_loss br implementation
* WIP on nll_loss (2)
* Updated tests
* Removed commented code
* Updated decomposition
* Fixed total_weight thanks to Richard's suggestion
* Removed nll_loss_nd and renamed nll_loss_forward_plumbing -> nll_loss_forward_decomposition
1 parent ab4a32f commit d6ef390

5 files changed: +76 −110 lines

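The commit message above describes replacing the hand-written nll_loss batch rule with a decomposition built on at::gather. As a rough Python sketch of the idea (my illustration, not code from this commit; it covers only the unweighted path with no ignore_index, and nll_loss_via_gather is a hypothetical name):

import torch
import torch.nn.functional as F

def nll_loss_via_gather(self, target, reduction="mean"):
    # Pick out the -log-probability of each sample's target class.
    channel_dim = 1 if self.dim() >= 2 else 0
    picked = torch.gather(self, channel_dim, target.unsqueeze(channel_dim))
    result = -picked.squeeze(channel_dim)
    if reduction == "sum":
        return result.sum()
    if reduction == "mean":
        return result.mean()
    return result  # reduction == "none"

x = torch.randn(4, 3).log_softmax(dim=1)
t = torch.tensor([0, 2, 1, 1])
assert torch.allclose(nll_loss_via_gather(x, t), F.nll_loss(x, t))

Because gather handles arbitrary strides, the decomposed operation composes with vmap without a dedicated batching rule, which is what allows the xfail entries in the test files below to be removed.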

functorch/csrc/BatchRulesHelper.h

Lines changed: 10 additions & 0 deletions
@@ -393,6 +393,16 @@ inline int64_t get_bdim_size2(
   TORCH_INTERNAL_ASSERT(false);
 }

+// [start, start + 1, ..., stop - 1]
+inline VmapDimVector range(int64_t start, int64_t stop) {
+  TORCH_INTERNAL_ASSERT(stop >= start);
+  VmapDimVector dims;
+  dims.reserve(stop - start);
+  for (int64_t i = start; i < stop; i++) {
+    dims.emplace_back(i);
+  }
+  return dims;
+}

 }}

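The range helper hoisted into BatchRulesHelper.h (it is deleted from BatchRulesReduceOps.cpp further down) builds the dim list [start, ..., stop - 1], much like Python's built-in range; dim lists like this are handy for, e.g., reducing over all non-batch dimensions at once. A hypothetical Python analogue (my illustration, not part of the commit):

import torch

def dim_range(start, stop):
    # [start, start + 1, ..., stop - 1], mirroring the C++ helper.
    assert stop >= start
    return list(range(start, stop))

t = torch.randn(2, 3, 4)
# e.g. reduce over every dimension except the leading batch dim:
per_batch_sum = t.sum(dim=dim_range(1, t.dim()))
assert per_batch_sum.shape == (2,)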
functorch/csrc/BatchRulesLoss.cpp

Lines changed: 66 additions & 95 deletions
@@ -67,108 +67,79 @@ mse_loss_backward_batch_rule(
   return std::make_tuple(result, 0);
 };

-std::tuple<at::Tensor,optional<int64_t>,at::Tensor,optional<int64_t>>
-nll_loss_forward_self_target_batch_rule(
-    const at::Tensor & self, optional<int64_t> self_bdim,
-    const at::Tensor & target, optional<int64_t> target_bdim,
-    int64_t reduction) {
-  TORCH_INTERNAL_ASSERT(self.dim() == 3 && target.dim() == 2);
-
-  if (reduction == Reduction::None) {
-    int64_t batch_size = self.size(*self_bdim);
-    auto self_ = reshape_dim_into(*self_bdim, 0, self);
-    auto target_ = reshape_dim_into(*target_bdim, 0, target);
-    auto result = at::nll_loss_forward(self_, target_, nullopt, reduction, -100);
-    return std::make_tuple(
-      reshape_dim_outof(0, batch_size, std::get<0>(result)), 0,
-      std::get<1>(result), nullopt
-    );
-  } else if (reduction == Reduction::Sum) {
-    int64_t batch_size = self.size(*self_bdim);
-    auto self_ = reshape_dim_into(*self_bdim, 0, self);
-    auto target_ = reshape_dim_into(*target_bdim, 0, target);
-    auto res = at::nll_loss_forward(self_, target_, nullopt, Reduction::None, -100);
-    auto output = std::get<0>(res);
-    output = reshape_dim_outof(0, batch_size, output);
-    auto total_weight = self_.new_full({}, output.size(-1));
-    return std::make_tuple(
-      output.sum(-1), 0,
-      // NB: total_weight = 0 after Reduction::None
-      total_weight, nullopt
-    );
-  } else if (reduction == Reduction::Mean) {
-    int64_t batch_size = self.size(*self_bdim);
-    auto self_ = reshape_dim_into(*self_bdim, 0, self);
-    auto target_ = reshape_dim_into(*target_bdim, 0, target);
-    auto res = at::nll_loss_forward(self_, target_, nullopt, Reduction::None, -100);
-    auto output = std::get<0>(res);
-    output = reshape_dim_outof(0, batch_size, output);
-    auto total_weight = self_.new_full({}, output.size(-1));
-    return std::make_tuple(
-      output.mean(-1), 0,
-      // NB: total_weight = 0 after Reduction::None
-      total_weight, nullopt
-    );
-  }
-  TORCH_INTERNAL_ASSERT(false);
-}
-
-std::tuple<at::Tensor,at::Tensor> nll_loss_forward_plumbing(
-    const at::Tensor & self,
-    const at::Tensor & target,
-    const c10::optional<at::Tensor> & weight,
+std::tuple<Tensor, Tensor> nll_loss_forward_decomposition(
+    const Tensor & self,
+    const Tensor & target,
+    const c10::optional<Tensor> & weight,
     int64_t reduction, int64_t ignore_index) {
-  auto maybe_layer = maybeCurrentDynamicLayer();
-  TORCH_INTERNAL_ASSERT(maybe_layer.has_value());
-  int64_t cur_level = maybe_layer->layerId();
-
-  Tensor self_value;
-  optional<int64_t> self_bdim;
-  std::tie(self_value, self_bdim) = unwrapTensorAtLevel(self, cur_level);

-  Tensor target_value;
-  optional<int64_t> target_bdim;
-  std::tie(target_value, target_bdim) = unwrapTensorAtLevel(target, cur_level);
-
-  optional<Tensor> weight_value;
-  optional<int64_t> weight_bdim;
-  if (weight) {
-    std::tie(weight_value, weight_bdim) = unwrapTensorAtLevel(*weight, cur_level);
+  bool has_ignore_index = ignore_index >= 0;
+  if (has_ignore_index) {
+    // fallback
+    if (target.dim() > 1) {
+      static auto op = c10::Dispatcher::singleton()
+        .findSchemaOrThrow("aten::nll_loss_nd", "");
+      return slow_fallback<Tensor, Tensor>(op, {self, target, weight, reduction, ignore_index});
+    } else {
+      static auto op = c10::Dispatcher::singleton()
+        .findSchemaOrThrow("aten::nll_loss_forward", "");
+      return slow_fallback<Tensor, Tensor>(op, {self, target, weight, reduction, ignore_index});
+    }
   }
+  // self can be [N, C, ...] or [C]
+  // target can be [N, ...] or []

-  if (self_bdim && target_bdim && (!weight || !weight->defined()) && ignore_index < 0) {
-    c10::impl::ExcludeDispatchKeyGuard guard(kBatchedKey);
-    auto results = nll_loss_forward_self_target_batch_rule(
-        self_value, self_bdim, target_value, target_bdim, reduction);
-    return std::make_tuple(
-        makeBatched(std::get<0>(results), std::get<1>(results), cur_level),
-        makeBatched(std::get<2>(results), std::get<3>(results), cur_level)
-    );
+  int64_t channel_dim = 1;
+  if (self.dim() < 2) {
+    channel_dim = 0;
   }
-
-  if ((!weight || !weight->defined()) && ignore_index < 0) {
-    // Decomposition: gather to get unreduced loss. 1 is for the C dim, that's always 1.
-    // gather can handle arbitrary strides so it's a good candidate for a decomposition.
-    auto target_ = target.unsqueeze(1);
-    auto result = at::gather(self, 1, target_).squeeze(1);
-    auto total_weight = at::full(
-        {}, result.numel(), self.scalar_type(),
-        self.layout(), self.device(), nullopt);
-
-    // Apply the reduction
-    switch (reduction) {
-      case Reduction::None:
-        return std::make_tuple(-result, total_weight);
-      case Reduction::Sum:
-        return std::make_tuple(-result.sum(), total_weight);
-      case Reduction::Mean:
-        return std::make_tuple(-result.mean(), total_weight);
+  auto self_ = self;
+  Tensor weight_;
+
+  if (weight && weight->defined()) {
+    // Here is a specific case with reduction mean and non-batched tensors
+    // https://github.com/pytorch/pytorch/issues/61309
+    // In this case weight is cancelled: w * x[t] / w -> x[t]
+    if (!(reduction == Reduction::Mean && self_.dim() < 2)) {
+      // reshape weights to [1, C, 1, ..., 1]
+      auto shape = weight->sizes();
+      VmapDimVector new_shape(self_.dim(), 1);
+      new_shape[channel_dim] = shape[0];
+      weight_ = weight->reshape(new_shape);
+      self_ = self_ * weight_;
     }
   }
+  auto target_ = target.unsqueeze(channel_dim);
+  // target can be [N, 1, ...] or [1]
+
+  auto result = -at::gather(self_, channel_dim, target_).squeeze(channel_dim);
+  auto total_weight = at::full(
+      {}, result.numel(), self_.scalar_type(),
+      self_.layout(), self_.device(), nullopt);
+
+  // Apply the reduction
+  if (result.dim() > 0) {
+    if (reduction == Reduction::Sum) {
+      result = result.sum();
+    } else if (reduction == Reduction::Mean) {
+      if (!weight || !weight->defined()) {
+        result = result.mean();
+      } else {
+        TORCH_INTERNAL_ASSERT(weight_.defined());
+        weight_ = weight_.expand(self_.sizes());
+        auto wsum = at::gather(weight_, channel_dim, target_).squeeze(channel_dim);
+        wsum = wsum.sum();
+        result = result.sum() / wsum;
+        total_weight = wsum;
+      }
+    }
+  } else if (reduction == Reduction::Mean && weight && weight->defined()) {
+    // here weight is [C] and target is [1]
+    auto wsum = at::gather(*weight, channel_dim, target_).squeeze(channel_dim);
+    total_weight = wsum.sum();
+  }

-  static auto op = c10::Dispatcher::singleton()
-    .findSchemaOrThrow("aten::nll_loss_forward", "");
-  return slow_fallback<Tensor,Tensor>(op, {self, target, weight, reduction, ignore_index});
+  return std::make_tuple(result, total_weight);
 }

 std::tuple<at::Tensor,optional<int64_t>>
@@ -272,7 +243,7 @@ at::Tensor nll_loss_backward_plumbing(


 TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) {
-  m.impl("nll_loss_forward", nll_loss_forward_plumbing);
+  m.impl("nll_loss_forward", nll_loss_forward_decomposition);
   m.impl("nll_loss_backward", nll_loss_backward_plumbing);
   VMAP_SUPPORT("mse_loss", mse_loss_batch_rule);
   VMAP_SUPPORT("mse_loss_backward", mse_loss_backward_batch_rule);
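For the weighted Reduction::Mean branch above, the decomposition computes sum(w[t_i] * loss_i) / sum(w[t_i]), gathering the per-sample weights by target class. A small Python check of that arithmetic (my illustration, not code from this commit):

import torch
import torch.nn.functional as F

x = torch.randn(5, 3).log_softmax(dim=1)
t = torch.tensor([0, 2, 1, 1, 0])
w = torch.tensor([0.2, 1.0, 3.0])

# Multiply inputs by the [1, C]-shaped weights, then gather by target,
# matching the self_ = self_ * weight_ step in the C++ code.
picked = -torch.gather(x * w.view(1, -1), 1, t.unsqueeze(1)).squeeze(1)
# total_weight is the sum of the gathered per-sample weights.
wsum = torch.gather(w.expand(x.shape), 1, t.unsqueeze(1)).squeeze(1).sum()
mean_loss = picked.sum() / wsum

assert torch.allclose(mean_loss, F.nll_loss(x, t, weight=w))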

functorch/csrc/BatchRulesReduceOps.cpp

Lines changed: 0 additions & 12 deletions
@@ -11,18 +11,6 @@

 namespace at { namespace functorch {

-// [start, start + 1, ..., stop - 1]
-static VmapDimVector range(int64_t start, int64_t stop) {
-  TORCH_INTERNAL_ASSERT(stop >= start);
-  VmapDimVector dims;
-  dims.reserve(stop - start);
-  for (int64_t i = start; i < stop; i++) {
-    dims.emplace_back(i);
-  }
-  return dims;
-}
-
-
 bool is_allowed_dim_on_scalar_tensor(int64_t dim) {
   return dim == 0 || dim == -1;
 }

test/test_ops.py

Lines changed: 0 additions & 2 deletions
@@ -360,7 +360,6 @@ def vjp_of_vjp(*args_and_cotangents):
        xfail('nanmean'),
        xfail('block_diag'),
        xfail('nn.functional.dropout'),
-        xfail('nn.functional.nll_loss'),
    }))
    def test_vmapvjp(self, device, dtype, op):
        # These are too annoying to put into the list above
@@ -539,7 +538,6 @@ def test():
        xfail('vstack'),
        xfail('block_diag'),
        xfail('nn.functional.batch_norm'),
-        xfail('nn.functional.nll_loss'),
        xfail('cdist'),
        xfail('lu_solve'),
        xfail('lu_unpack'),

test/test_vmap.py

Lines changed: 0 additions & 1 deletion
@@ -3015,7 +3015,6 @@ class TestVmapOperatorsOpInfo(TestCase):
        xfail('fft.rfft'),
        xfail('fft.rfftn'),
        xfail('nn.functional.batch_norm'),
-        xfail('nn.functional.nll_loss'),
        xfail('lu_unpack'),
        xfail('nn.functional.pad', 'constant'),
    })
