Skip to content

Commit b2d72a4

Browse files
Revert "Don't hardcode double argument for reduction base (pytorch#166951)"
This reverts commit a74fe75. Reverted pytorch#166951 on behalf of https://github.com/facebook-github-bot due to Diff reverted internally ([comment](pytorch#166951 (comment)))
1 parent 80ec2ab commit b2d72a4

File tree

2 files changed

+23
-3
lines changed

2 files changed

+23
-3
lines changed

aten/src/ATen/native/cpu/Reduce.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -247,8 +247,8 @@ void binary_kernel_reduce(TensorIteratorBase& iter, ops_t ops, init_t init) {
247247
});
248248
}
249249

250-
template <typename func_t, typename vec_func_t, typename ident_t = double>
251-
void binary_kernel_reduce_vec(TensorIteratorBase& iter, func_t op, vec_func_t vop, ident_t ident = static_cast<ident_t>(0)) {
250+
template <typename func_t, typename vec_func_t>
251+
void binary_kernel_reduce_vec(TensorIteratorBase& iter, func_t op, vec_func_t vop, double ident = 0) {
252252
using traits = binary_function_traits<func_t>;
253253
static_assert(
254254
all_same<

aten/src/ATen/native/cpu/ReduceOpsKernel.cpp

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -339,13 +339,33 @@ void or_kernel_impl(TensorIterator& iter) {
339339
}
340340
}
341341

342+
template<typename scalar_t>
343+
struct MinValuesOps: public at::native::MinOps<scalar_t> {
344+
using arg_t = typename MinOps<scalar_t>::arg_t;
345+
static scalar_t project(arg_t arg) {
346+
return arg.first;
347+
}
348+
};
349+
342350
void min_values_kernel_impl(TensorIterator& iter) {
351+
// This case is special because of Vectorized<int64_t> does not
352+
// handle upper_bound<int64_t>().
353+
// See: https://github.com/pytorch/pytorch/issues/43254
354+
if (iter.dtype() == kLong || iter.dtype() == kUInt64) {
355+
AT_DISPATCH_V2(iter.dtype(), "min_values_cpu", AT_WRAP([&iter] {
356+
binary_kernel_reduce(
357+
iter,
358+
MinValuesOps<scalar_t>{},
359+
std::pair<scalar_t, int64_t>(upper_bound<scalar_t>(), -1));
360+
}), kLong, kUInt64);
361+
return;
362+
}
343363
AT_DISPATCH_V2(iter.dtype(), "min_values_cpu", AT_WRAP([&iter] {
344364
binary_kernel_reduce_vec(
345365
iter,
346366
[](scalar_t a, scalar_t b) -> scalar_t { return min_impl(a, b); },
347367
[](Vectorized<scalar_t> a, Vectorized<scalar_t> b) { return minimum(a, b); },
348-
upper_bound<scalar_t>());
368+
static_cast<double>(upper_bound<scalar_t>()));
349369
}), AT_EXPAND(AT_ALL_TYPES), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES), kBFloat16, kHalf, kBool);
350370
}
351371

0 commit comments

Comments
 (0)