This repository was archived by the owner on Aug 21, 2025. It is now read-only.

Commit dc82a73

add aminmax batching rule (#180)
* add aminmax batching rule
* special case for cuda
* update code and add comment
1 parent 17f407c

File tree: 2 files changed, +33 -1 lines


functorch/csrc/BatchRulesReduceOps.cpp

Lines changed: 33 additions & 0 deletions
@@ -255,6 +255,38 @@ std::tuple<Tensor,optional<int64_t>> _log_softmax_backward_batch_rule(
   return std::make_tuple(at::_log_softmax_backward_data(grad_output_, output_, dim, input_dtype), 0);
 }
 
+// aminmax has divergent behavior for 0-d tensors.
+// reference: https://github.com/pytorch/pytorch/issues/64008
+// TODO: Once the divergent behavior for 0-d scalars is fixed, we should use REDUCTION_BOXED_ARGS
+std::tuple<Tensor, optional<int64_t>, Tensor, optional<int64_t>> aminmax_batching_rule(
+    const Tensor &self, optional<int64_t> self_bdim, optional<int64_t> dim, bool keep_dim)
+{
+  auto self_ = moveBatchDimToFront(self, self_bdim);
+  auto logical_rank = rankWithoutBatchDim(self_, self_bdim);
+  if (logical_rank == 0) {
+    self_ = self_.unsqueeze(-1);
+  }
+
+  if (dim.has_value()) {
+    dim = maybe_wrap_dim(dim.value(), logical_rank) + 1;
+  } else {
+    // flatten the input except for batch-dim
+    auto bsize = self_.size(0);
+    self_ = self_.view({bsize, -1});
+    dim = 1;
+  }
+
+  Tensor min, max;
+  std::tie(min, max) = at::aminmax(self_, dim, keep_dim);
+
+  if (logical_rank == 0 && self_.device().is_cuda()) {
+    // behaviour diverges between cpu and cuda
+    min = min.squeeze(-1);
+    max = max.squeeze(-1);
+  }
+  return std::make_tuple(min, 0, max, 0);
+}
+
 TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) {
   REDUCTION_BOXED(_fft_r2c);
   REDUCTION_BOXED(_fft_c2r);
@@ -306,6 +338,7 @@ TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) {
   REDUCTION_BOXED(var_mean.correction);
   REDUCTION_BOXED(_log_softmax);
   REDUCTION_BOXED_ARGS(rot90, 2);
+  VMAP_SUPPORT("aminmax", aminmax_batching_rule);
 
   VMAP_SUPPORT("_log_softmax_backward_data", _log_softmax_backward_batch_rule);
   VMAP_SUPPORT("_softmax_backward_data", _softmax_backward_batch_rule);

test/test_vmap.py

Lines changed: 0 additions & 1 deletion
@@ -3023,7 +3023,6 @@ def test_vmap_exhaustive(self, device, dtype, op):
     @ops(functorch_lagging_op_db + additional_op_db, allowed_dtypes=(torch.float,))
     @skipOps('TestVmapOperatorsOpInfo', 'test_op_has_batch_rule', {
         # xfail('__getitem__'),
-        xfail('aminmax'),
         xfail('broadcast_to'),
         xfail('cdist'),
         xfail('complex'),
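With the batching rule registered, the xfail('aminmax') entry is no longer needed, so the OpInfo-based test_op_has_batch_rule now exercises aminmax under vmap. Conceptually, vmap over aminmax should agree with a per-example loop; the reference_aminmax helper below is an illustrative sketch of that check, not the test suite's actual code:

import torch
from functorch import vmap

def reference_aminmax(x, dim=None):
    # Per-example loop that vmap over torch.aminmax should match.
    results = [torch.aminmax(xi, dim=dim) for xi in x.unbind(0)]
    mins = torch.stack([r.min for r in results])
    maxs = torch.stack([r.max for r in results])
    return mins, maxs

x = torch.randn(4, 3, 5)
vmin, vmax = vmap(lambda t: torch.aminmax(t, dim=0))(x)
rmin, rmax = reference_aminmax(x, dim=0)
assert torch.allclose(vmin, rmin) and torch.allclose(vmax, rmax)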
