
Commit ff4c4a4

manuelcandales authored and facebook-github-bot committed
Add ops: max.unary_out & min.unary_out
Differential Revision: D64986580
1 parent 8f9fb7e commit ff4c4a4

File tree: 8 files changed, +232 −8 lines

kernels/aten/functions.yaml
kernels/portable/cpu/op_max.cpp
kernels/portable/cpu/op_min.cpp
kernels/portable/cpu/util/math_util.h
kernels/portable/functions.yaml
kernels/test/op_max_test.cpp
kernels/test/op_min_test.cpp
shim/xplat/executorch/kernels/portable/op_registration_util.bzl
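Background, not part of the commit page: max.unary_out and min.unary_out are the out-variant overloads of the whole-tensor reductions (torch.max(x) / torch.min(x) with no dim argument); they collapse every element of the input into a single 0-D output tensor. A minimal sketch of that semantics in plain C++, as an illustration under my own assumptions rather than the ExecuTorch kernel; it also ignores the NaN propagation that utils::max_override adds.

#include <algorithm>
#include <cstdio>
#include <limits>
#include <vector>

// Sketch only: whole-tensor max seeded with -infinity, so an empty input
// yields -inf (the same identity the new kernel uses for floating types).
// The real kernel also propagates NaN via utils::max_override; std::max does not.
float max_all(const std::vector<float>& in) {
  float acc = -std::numeric_limits<float>::infinity();
  for (float v : in) {
    acc = std::max(acc, v);
  }
  return acc;
}

int main() {
  std::printf("%f\n", max_all({0, 1, 2, 4, 4, 2}));  // 4, as in the new tests
  std::printf("%f\n", max_all({}));                  // -inf for an empty input
  return 0;
}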

kernels/aten/functions.yaml

Lines changed: 4 additions & 0 deletions
@@ -249,12 +249,16 @@
 
 - op: max.dim_max
 
+- op: max.unary_out
+
 - op: maximum.out
 
 - op: mean.out
 
 - op: min.dim_min
 
+- op: min.unary_out
+
 - op: minimum.out
 
 - op: mm.out
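For reference, and as an assumption to verify against the pinned PyTorch revision rather than something stated in this diff, the ATen schemas these two entries point at should be the dim-less out variants:

max.unary_out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
min.unary_out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)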

kernels/portable/cpu/op_max.cpp

Lines changed: 45 additions & 1 deletion
@@ -9,14 +9,23 @@
 #include <cmath>
 #include <tuple>
 
-#include <executorch/kernels/portable/cpu/util/index_util.h>
 #include <executorch/kernels/portable/cpu/util/reduce_util.h>
+#include <executorch/kernels/portable/cpu/util/math_util.h>
 #include <executorch/runtime/kernel/kernel_includes.h>
 #include <executorch/runtime/platform/assert.h>
 
 namespace torch {
 namespace executor {
 namespace native {
+namespace {
+
+template <typename CTYPE>
+constexpr CTYPE lower_bound() {
+  using lim = std::numeric_limits<CTYPE>;
+  return lim::has_infinity ? -lim::infinity() : lim::lowest();
+}
+
+} // namespace
 
 using ScalarType = exec_aten::ScalarType;
 using SizesType = exec_aten::SizesType;

@@ -94,6 +103,41 @@ std::tuple<Tensor&, Tensor&> max_out(
   return {max, max_indices};
 }
 
+Tensor& max_unary_out(
+    KernelRuntimeContext& ctx,
+    const Tensor& in,
+    Tensor& out) {
+  (void)ctx;
+
+  ET_KERNEL_CHECK(
+      ctx, resize_tensor(out, {}) == Error::Ok, InvalidArgument, out);
+
+  ET_KERNEL_CHECK(
+      ctx, tensors_have_same_dim_order(in, out), InvalidArgument, out);
+
+  ScalarType in_type = in.scalar_type();
+  ScalarType out_type = out.scalar_type();
+
+  ET_KERNEL_CHECK(
+      ctx, canCast(in_type, out_type), InvalidArgument, out);
+
+  constexpr auto name = "max.unary_out";
+
+  ET_SWITCH_REALHBBF16_TYPES(in_type, ctx, name, CTYPE_IN, [&] {
+    ET_SWITCH_REALHBBF16_TYPES(out_type, ctx, name, CTYPE_OUT, [&] {
+      const auto data_in = in.const_data_ptr<CTYPE_IN>();
+      auto data_out = out.mutable_data_ptr<CTYPE_OUT>();
+      data_out[0] = lower_bound<CTYPE_OUT>();
+      for (auto i = 0; i < in.numel(); ++i) {
+        data_out[0] = utils::max_override(
+            static_cast<CTYPE_OUT>(data_in[i]), data_out[0]);
+      }
+    });
+  });
+
+  return out;
+}
+
 } // namespace native
 } // namespace executor
 } // namespace torch
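The kernel body above is a straightforward full-tensor fold: the output is resized to a 0-D tensor, seeded with lower_bound<CTYPE_OUT>() (−inf for floating types, numeric_limits::lowest() otherwise), and folded with utils::max_override, which propagates NaN. A standalone sketch of the same fold using only the standard library follows; the names are mine and it is an illustration, not ExecuTorch code.

#include <cmath>
#include <cstddef>
#include <limits>

// Same identity rule as the lower_bound() helper in the diff: -inf when the
// type has an infinity, otherwise the lowest finite value.
template <typename T>
constexpr T lower_bound() {
  using lim = std::numeric_limits<T>;
  return lim::has_infinity ? -lim::infinity() : lim::lowest();
}

// NaN-propagating max, mirroring what utils::max_override does for
// floating-point types (integer types reduce to a plain comparison).
template <typename T>
T max_propagate_nan(T a, T b) {
  if (std::isnan(static_cast<double>(a))) {
    return a;
  }
  if (std::isnan(static_cast<double>(b))) {
    return b;
  }
  return a > b ? a : b;
}

// Reduce n input elements to one output value; n == 0 leaves the identity,
// which is what the empty-input tests below rely on.
template <typename OUT, typename IN>
OUT reduce_max(const IN* data, std::size_t n) {
  OUT acc = lower_bound<OUT>();
  for (std::size_t i = 0; i < n; ++i) {
    acc = max_propagate_nan(static_cast<OUT>(data[i]), acc);
  }
  return acc;
}

int main() {
  const int data[] = {0, 1, 2, 4, 4, 2};
  return reduce_max<float>(data, 6) == 4.0f ? 0 : 1;  // int input, float output
}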

kernels/portable/cpu/op_min.cpp

Lines changed: 45 additions & 1 deletion
@@ -9,14 +9,23 @@
 #include <cmath>
 #include <tuple>
 
-#include <executorch/kernels/portable/cpu/util/index_util.h>
 #include <executorch/kernels/portable/cpu/util/reduce_util.h>
+#include <executorch/kernels/portable/cpu/util/math_util.h>
 #include <executorch/runtime/kernel/kernel_includes.h>
 #include <executorch/runtime/platform/assert.h>
 
 namespace torch {
 namespace executor {
 namespace native {
+namespace {
+
+template <typename CTYPE>
+constexpr CTYPE upper_bound() {
+  using lim = std::numeric_limits<CTYPE>;
+  return lim::has_infinity ? lim::infinity() : lim::max();
+}
+
+} // namespace
 
 using ScalarType = exec_aten::ScalarType;
 using SizesType = exec_aten::SizesType;

@@ -94,6 +103,41 @@ std::tuple<Tensor&, Tensor&> min_out(
   return {min, min_indices};
 }
 
+Tensor& min_unary_out(
+    KernelRuntimeContext& ctx,
+    const Tensor& in,
+    Tensor& out) {
+  (void)ctx;
+
+  ET_KERNEL_CHECK(
+      ctx, resize_tensor(out, {}) == Error::Ok, InvalidArgument, out);
+
+  ET_KERNEL_CHECK(
+      ctx, tensors_have_same_dim_order(in, out), InvalidArgument, out);
+
+  ScalarType in_type = in.scalar_type();
+  ScalarType out_type = out.scalar_type();
+
+  ET_KERNEL_CHECK(
+      ctx, canCast(in_type, out_type), InvalidArgument, out);
+
+  constexpr auto name = "min.unary_out";
+
+  ET_SWITCH_REALHBBF16_TYPES(in_type, ctx, name, CTYPE_IN, [&] {
+    ET_SWITCH_REALHBBF16_TYPES(out_type, ctx, name, CTYPE_OUT, [&] {
+      const auto data_in = in.const_data_ptr<CTYPE_IN>();
+      auto data_out = out.mutable_data_ptr<CTYPE_OUT>();
+      data_out[0] = upper_bound<CTYPE_OUT>();
+      for (auto i = 0; i < in.numel(); ++i) {
+        data_out[0] = utils::min_override(
+            static_cast<CTYPE_OUT>(data_in[i]), data_out[0]);
+      }
+    });
+  });
+
+  return out;
+}
+
 } // namespace native
 } // namespace executor
 } // namespace torch
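min_unary_out is the mirror image: upper_bound() supplies the identity (+inf for floating types, numeric_limits::max() otherwise) and utils::min_override performs the fold, so an empty input yields +inf or the type's maximum, which is exactly what the new op_min_test cases assert. A couple of compile-time checks of that identity, sketched under my own naming rather than the ExecuTorch header:

#include <cstdint>
#include <limits>

// Same shape as the upper_bound() helper added in this file (sketch only).
template <typename T>
constexpr T upper_bound() {
  using lim = std::numeric_limits<T>;
  return lim::has_infinity ? lim::infinity() : lim::max();
}

// Identity of min: every representable value compared against it is <= it.
static_assert(
    upper_bound<int32_t>() == std::numeric_limits<int32_t>::max(),
    "integral types fall back to numeric_limits::max()");
static_assert(
    upper_bound<float>() == std::numeric_limits<float>::infinity(),
    "floating types use +infinity");

int main() {
  return 0;
}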

kernels/portable/cpu/util/math_util.h

Lines changed: 8 additions & 4 deletions
@@ -96,8 +96,10 @@ INT_T max_override(INT_T a, INT_T b) {
 
 template <
     typename T,
-    typename std::enable_if<std::is_same<T, exec_aten::Half>::value, bool>::
-        type = true>
+    typename std::enable_if<
+        std::is_same<T, exec_aten::Half>::value ||
+            std::is_same<T, exec_aten::BFloat16>::value,
+        bool>::type = true>
 T min_override(T a, T b) {
   const auto float_a = static_cast<float>(a);
   if (std::isnan(float_a)) {

@@ -116,8 +118,10 @@ T min_override(T a, T b) {
 
 template <
     typename T,
-    typename std::enable_if<std::is_same<T, exec_aten::Half>::value, bool>::
-        type = true>
+    typename std::enable_if<
+        std::is_same<T, exec_aten::Half>::value ||
+            std::is_same<T, exec_aten::BFloat16>::value,
+        bool>::type = true>
 T max_override(T a, T b) {
   const auto float_a = static_cast<float>(a);
   if (std::isnan(float_a)) {
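The header change widens the reduced-precision overloads of min_override / max_override from Half only to Half or BFloat16, which is what lets the new kernels iterate under ET_SWITCH_REALHBBF16_TYPES. Those overloads upcast to float so NaN is detected and propagated before the comparison. A minimal reproduction of that dispatch pattern with a stand-in type follows; FakeBF16 and the overload set are my own illustration, not the ExecuTorch names.

#include <cmath>
#include <cstdio>
#include <type_traits>

// Stand-in 16-bit-ish type; the real code uses exec_aten::Half / exec_aten::BFloat16.
struct FakeBF16 {
  float value;
  explicit operator float() const { return value; }
};

// Overload for ordinary arithmetic floating types.
template <
    typename T,
    typename std::enable_if<std::is_floating_point<T>::value, bool>::type = true>
T min_override(T a, T b) {
  if (std::isnan(a)) {
    return a;
  }
  if (std::isnan(b)) {
    return b;
  }
  return a < b ? a : b;
}

// Overload selected for the reduced-precision type: upcast to float, decide
// there, return the original operand (mirrors the Half/BFloat16 branch).
template <
    typename T,
    typename std::enable_if<std::is_same<T, FakeBF16>::value, bool>::type = true>
T min_override(T a, T b) {
  const float fa = static_cast<float>(a);
  const float fb = static_cast<float>(b);
  if (std::isnan(fa)) {
    return a;
  }
  if (std::isnan(fb)) {
    return b;
  }
  return fa < fb ? a : b;
}

int main() {
  FakeBF16 x{2.0f}, y{NAN};
  std::printf("%f\n", static_cast<float>(min_override(x, y)));  // nan propagates
  std::printf("%f\n", min_override(1.0, 3.0));                  // 1.000000
  return 0;
}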

kernels/portable/functions.yaml

Lines changed: 10 additions & 0 deletions
@@ -552,6 +552,11 @@
     - arg_meta: null
       kernel_name: torch::executor::max_out
 
+- op: max.unary_out
+  kernels:
+    - arg_meta: null
+      kernel_name: torch::executor::max_unary_out
+
 - op: maximum.out
   kernels:
     - arg_meta: null

@@ -572,6 +577,11 @@
     - arg_meta: null
       kernel_name: torch::executor::min_out
 
+- op: min.unary_out
+  kernels:
+    - arg_meta: null
+      kernel_name: torch::executor::min_unary_out
+
 - op: minimum.out
   kernels:
     - arg_meta: null

kernels/test/op_max_test.cpp

Lines changed: 59 additions & 0 deletions
@@ -222,6 +222,65 @@ void OpMaxOutTest::test_max_out_dtype<ScalarType::Bool>() {
   // clang-format on
 }
 
+class OpMaxUnaryOutTest : public OperatorTest {
+ protected:
+  Tensor& op_max_unary_out(
+      const Tensor& self,
+      Tensor& out) {
+    return torch::executor::aten::max_outf(
+        context_, self, out);
+  }
+
+  template <ScalarType IN_DTYPE>
+  void test_max_unary_out_dtype() {
+    TensorFactory<IN_DTYPE> tf_in;
+    TensorFactory<ScalarType::Float> tf_out;
+    Tensor input = tf_in.make({2, 3}, {0, 1, 2, 4, 4, 2});
+    Tensor out = tf_out.zeros({});
+    Tensor expected = tf_out.make({}, {4});
+    op_max_unary_out(input, out);
+    EXPECT_TENSOR_CLOSE(out, expected);
+  }
+
+  template <typename CTYPE, ScalarType IN_DTYPE>
+  void test_max_unary_out_empty_integer() {
+    TensorFactory<IN_DTYPE> tf_in;
+    Tensor input = tf_in.make({2, 0}, {});
+    Tensor out = tf_in.zeros({});
+    Tensor expected = tf_in.make({}, {std::numeric_limits<CTYPE>::lowest()});
+    op_max_unary_out(input, out);
+    EXPECT_TENSOR_CLOSE(out, expected);
+  }
+
+  template <typename CTYPE, ScalarType IN_DTYPE>
+  void test_max_unary_out_empty_floating() {
+    TensorFactory<IN_DTYPE> tf_in;
+    Tensor input = tf_in.make({2, 0}, {});
+    Tensor out = tf_in.zeros({});
+    Tensor expected = tf_in.make({}, {-INFINITY});
+    op_max_unary_out(input, out);
+    EXPECT_TENSOR_CLOSE(out, expected);
+  }
+};
+
+TEST_F(OpMaxUnaryOutTest, AllRealHBF16InputFloatOutputPasses) {
+#define TEST_ENTRY(ctype, dtype) test_max_unary_out_dtype<ScalarType::dtype>();
+  ET_FORALL_REALHBF16_TYPES(TEST_ENTRY);
+#undef TEST_ENTRY
+}
+
+TEST_F(OpMaxUnaryOutTest, EmptyIntegerInput) {
+#define TEST_ENTRY(ctype, dtype) test_max_unary_out_empty_integer<ctype, ScalarType::dtype>();
+  ET_FORALL_INT_TYPES(TEST_ENTRY);
+#undef TEST_ENTRY
+}
+
+TEST_F(OpMaxUnaryOutTest, EmptyFloatingInput) {
+#define TEST_ENTRY(ctype, dtype) test_max_unary_out_empty_floating<ctype, ScalarType::dtype>();
+  ET_FORALL_FLOATHBF16_TYPES(TEST_ENTRY);
+#undef TEST_ENTRY
+}
+
 TEST_F(OpMaxOutTest, MismatchedDimensionsDies) {
   if (torch::executor::testing::SupportedFeatures::get()->is_aten) {
     GTEST_SKIP() << "ATen kernel test fails";
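The suite above covers every REALHBF16 input dtype plus the empty-input identities. As a further illustration of why the accumulator starts at lower_bound() rather than zero, here is a hypothetical extra case in the same style, reusing the OpMaxUnaryOutTest helpers defined in the diff; it is not part of the commit.

TEST_F(OpMaxUnaryOutTest, AllNegativeIntegerInput) {
  TensorFactory<ScalarType::Int> tf_in;
  TensorFactory<ScalarType::Float> tf_out;
  Tensor input = tf_in.make({2, 2}, {-7, -1, -3, -4});
  Tensor out = tf_out.zeros({});
  // The true maximum is -1; seeding the reduction with 0 instead of
  // lower_bound() would wrongly report 0 here.
  Tensor expected = tf_out.make({}, {-1});
  op_max_unary_out(input, out);
  EXPECT_TENSOR_CLOSE(out, expected);
}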

kernels/test/op_min_test.cpp

Lines changed: 59 additions & 0 deletions
@@ -218,6 +218,65 @@ EXPECT_TENSOR_EQ(min_indices, tf_long.make(
   // clang-format on
 }
 
+class OpMinUnaryOutTest : public OperatorTest {
+ protected:
+  Tensor& op_min_unary_out(
+      const Tensor& self,
+      Tensor& out) {
+    return torch::executor::aten::min_outf(
+        context_, self, out);
+  }
+
+  template <ScalarType IN_DTYPE>
+  void test_min_unary_out_dtype() {
+    TensorFactory<IN_DTYPE> tf_in;
+    TensorFactory<ScalarType::Float> tf_out;
+    Tensor input = tf_in.make({2, 3}, {7, 1, 3, 4, 4, 2});
+    Tensor out = tf_out.zeros({});
+    Tensor expected = tf_out.make({}, {1});
+    op_min_unary_out(input, out);
+    EXPECT_TENSOR_CLOSE(out, expected);
+  }
+
+  template <typename CTYPE, ScalarType IN_DTYPE>
+  void test_min_unary_out_empty_integer() {
+    TensorFactory<IN_DTYPE> tf_in;
+    Tensor input = tf_in.make({2, 0}, {});
+    Tensor out = tf_in.zeros({});
+    Tensor expected = tf_in.make({}, {std::numeric_limits<CTYPE>::max()});
+    op_min_unary_out(input, out);
+    EXPECT_TENSOR_CLOSE(out, expected);
+  }
+
+  template <typename CTYPE, ScalarType IN_DTYPE>
+  void test_min_unary_out_empty_floating() {
+    TensorFactory<IN_DTYPE> tf_in;
+    Tensor input = tf_in.make({2, 0}, {});
+    Tensor out = tf_in.zeros({});
+    Tensor expected = tf_in.make({}, {INFINITY});
+    op_min_unary_out(input, out);
+    EXPECT_TENSOR_CLOSE(out, expected);
+  }
+};
+
+TEST_F(OpMinUnaryOutTest, AllRealHBF16InputFloatOutputPasses) {
+#define TEST_ENTRY(ctype, dtype) test_min_unary_out_dtype<ScalarType::dtype>();
+  ET_FORALL_REALHBF16_TYPES(TEST_ENTRY);
+#undef TEST_ENTRY
+}
+
+TEST_F(OpMinUnaryOutTest, EmptyIntegerInput) {
+#define TEST_ENTRY(ctype, dtype) test_min_unary_out_empty_integer<ctype, ScalarType::dtype>();
+  ET_FORALL_INT_TYPES(TEST_ENTRY);
+#undef TEST_ENTRY
+}
+
+TEST_F(OpMinUnaryOutTest, EmptyFloatingInput) {
+#define TEST_ENTRY(ctype, dtype) test_min_unary_out_empty_floating<ctype, ScalarType::dtype>();
+  ET_FORALL_FLOATHBF16_TYPES(TEST_ENTRY);
+#undef TEST_ENTRY
+}
+
 TEST_F(OpMinOutTest, MismatchedDimensionsDies) {
   if (torch::executor::testing::SupportedFeatures::get()->is_aten) {
     GTEST_SKIP() << "ATen kernel test fails";

shim/xplat/executorch/kernels/portable/op_registration_util.bzl

Lines changed: 2 additions & 2 deletions
@@ -787,7 +787,7 @@ ATEN_OPS = (
         deps = [
             "//executorch/runtime/core/exec_aten/util:scalar_type_util",
             "//executorch/runtime/core/exec_aten/util:tensor_util",
-            "//executorch/kernels/portable/cpu/util:index_util",
+            "//executorch/kernels/portable/cpu/util:math_util",
             "//executorch/kernels/portable/cpu/util:reduce_util",
         ],
     ),

@@ -821,7 +821,7 @@ ATEN_OPS = (
         deps = [
             "//executorch/runtime/core/exec_aten/util:scalar_type_util",
             "//executorch/runtime/core/exec_aten/util:tensor_util",
-            "//executorch/kernels/portable/cpu/util:index_util",
+            "//executorch/kernels/portable/cpu/util:math_util",
             "//executorch/kernels/portable/cpu/util:reduce_util",
         ],
     ),
