@@ -255,6 +255,38 @@ std::tuple<Tensor,optional<int64_t>> _log_softmax_backward_batch_rule(
255
255
return std::make_tuple (at::_log_softmax_backward_data (grad_output_, output_, dim, input_dtype), 0 );
256
256
}
257
257
258
+ // aminmax has divergent behavior for 0-d tensors.
259
+ // reference: https://github.com/pytorch/pytorch/issues/64008
260
+ // TODO: Once the divergent behavior for 0-d scalar is fixed, we should use REDUCTION_BOXED_ARGS
261
+ std::tuple<Tensor, optional<int64_t >, Tensor, optional<int64_t >> aminmax_batching_rule (
262
+ const Tensor &self, optional<int64_t > self_bdim, optional<int64_t > dim, bool keep_dim)
263
+ {
264
+ auto self_ = moveBatchDimToFront (self, self_bdim);
265
+ auto logical_rank = rankWithoutBatchDim (self_, self_bdim);
266
+ if (logical_rank == 0 ) {
267
+ self_ = self_.unsqueeze (-1 );
268
+ }
269
+
270
+ if (dim.has_value ()) {
271
+ dim = maybe_wrap_dim (dim.value (), logical_rank) + 1 ;
272
+ } else {
273
+ // flatten the input except for batch-dim
274
+ auto bsize = self.size (0 );
275
+ self_ = self.view ({bsize, -1 });
276
+ dim = 1 ;
277
+ }
278
+
279
+ Tensor min, max;
280
+ std::tie (min, max) = at::aminmax (self_, dim, keep_dim);
281
+
282
+ if (logical_rank == 0 && self_.device ().is_cuda ()) {
283
+ // behaviour diverges between cpu and cuda
284
+ min = min.squeeze (-1 );
285
+ max = max.squeeze (-1 );
286
+ }
287
+ return std::make_tuple (min, 0 , max, 0 );
288
+ }
289
+
258
290
TORCH_LIBRARY_IMPL (aten, FT_BATCHED_KEY, m) {
259
291
REDUCTION_BOXED (_fft_r2c);
260
292
REDUCTION_BOXED (_fft_c2r);
@@ -306,6 +338,7 @@ TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) {
306
338
REDUCTION_BOXED (var_mean.correction );
307
339
REDUCTION_BOXED (_log_softmax);
308
340
REDUCTION_BOXED_ARGS (rot90, 2 );
341
+ VMAP_SUPPORT (" aminmax" , aminmax_batching_rule);
309
342
310
343
VMAP_SUPPORT (" _log_softmax_backward_data" , _log_softmax_backward_batch_rule);
311
344
VMAP_SUPPORT (" _softmax_backward_data" , _softmax_backward_batch_rule);
0 commit comments