Skip to content

Commit 7c6ecd6

Browse files
committed
Update
[ghstack-poisoned]
2 parents 92649e7 + 4bc2029 commit 7c6ecd6

19 files changed

+169
-125
lines changed

kernels/portable/cpu/op_fill.cpp

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ Tensor& fill_scalar_out(
4242
out,
4343
"Failed to resize output tensor.");
4444

45-
ET_SWITCH_REAL_TYPES_AND(Bool, a_type, ctx, "fill.Scalar_out", CTYPE_A, [&] {
45+
ET_SWITCH_REALHBBF16_TYPES(a_type, ctx, "fill.Scalar_out", CTYPE_A, [&] {
4646
CTYPE_A b_casted;
4747
ET_SWITCH_SCALAR_OBJ_TYPES(b_type, ctx, "fill.Scalar_out", CTYPE_B, [&] {
4848
CTYPE_B b_val;
@@ -87,14 +87,14 @@ Tensor& fill_tensor_out(
8787
out,
8888
"Failed to resize output tensor.");
8989

90-
ET_SWITCH_REAL_TYPES_AND(Bool, a_type, ctx, "fill.Tensor_out", CTYPE_A, [&] {
90+
ET_SWITCH_REALHBBF16_TYPES(a_type, ctx, "fill.Tensor_out", CTYPE_A, [&] {
9191
CTYPE_A b_casted;
92-
ET_SWITCH_REAL_TYPES_AND(
93-
Bool, b_type, ctx, "fill.Tensor_out", CTYPE_B, [&] {
94-
CTYPE_B b_val;
95-
extract_scalar_tensor(b, &b_val);
96-
b_casted = static_cast<CTYPE_A>(b_val);
97-
});
92+
ET_SWITCH_REALHBBF16_TYPES(b_type, ctx, "fill.Tensor_out", CTYPE_B, [&] {
93+
CTYPE_B b_val;
94+
ET_DCHECK_MSG(
95+
extract_scalar_tensor(b, &b_val), "extract_scalar_tensor failed!");
96+
b_casted = static_cast<CTYPE_A>(b_val);
97+
});
9898

9999
apply_unary_map_fn(
100100
[b_casted](const CTYPE_A val_a) { return b_casted; },

kernels/portable/cpu/op_gather.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,7 @@ Tensor& gather_out(
8686

8787
constexpr auto name = "gather.out";
8888

89-
ET_SWITCH_REALHB_TYPES(in.scalar_type(), ctx, name, CTYPE, [&]() {
89+
ET_SWITCH_REALHBBF16_TYPES(in.scalar_type(), ctx, name, CTYPE, [&]() {
9090
gather_helper<CTYPE>(in, index, out, dim);
9191
});
9292

kernels/portable/cpu/op_leaky_relu.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ Tensor& leaky_relu_out(
4444

4545
ET_KERNEL_CHECK(ctx, in_type == out_type, InvalidArgument, out);
4646

47-
ET_SWITCH_FLOAT_TYPES(in_type, ctx, "leaky_relu.out", CTYPE, [&]() {
47+
ET_SWITCH_FLOATHBF16_TYPES(in_type, ctx, "leaky_relu.out", CTYPE, [&]() {
4848
CTYPE negative_slope_casted;
4949
ET_SWITCH_SCALAR_OBJ_TYPES(
5050
sc_type, ctx, "leaky_relu.out", CTYPE_MIN, [&]() {

kernels/portable/cpu/op_log_softmax.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ Tensor& log_softmax_out(
4242
// Adjust for negative dim
4343
dim = dim < 0 ? dim + nonzero_dim(in) : dim;
4444

45-
ET_SWITCH_FLOAT_TYPES(
45+
ET_SWITCH_FLOATHBF16_TYPES(
4646
in.scalar_type(), ctx, "_log_softmax.out", CTYPE, [&]() {
4747
const CTYPE* const in_data = in.const_data_ptr<CTYPE>();
4848
CTYPE* const out_data = out.mutable_data_ptr<CTYPE>();

kernels/portable/cpu/op_logical_not.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -33,10 +33,10 @@ logical_not_out(KernelRuntimeContext& ctx, const Tensor& in, Tensor& out) {
3333

3434
ET_KERNEL_CHECK(ctx, tensors_have_same_shape(in, out), InvalidArgument, out);
3535

36-
ET_SWITCH_REAL_TYPES_AND(
37-
Bool, in.scalar_type(), ctx, "logical_not.out", CTYPE_IN, [&] {
38-
ET_SWITCH_REAL_TYPES_AND(
39-
Bool, out.scalar_type(), ctx, "logical_not.out", CTYPE_OUT, [&] {
36+
ET_SWITCH_REALHBBF16_TYPES(
37+
in.scalar_type(), ctx, "logical_not.out", CTYPE_IN, [&] {
38+
ET_SWITCH_REALHBBF16_TYPES(
39+
out.scalar_type(), ctx, "logical_not.out", CTYPE_OUT, [&] {
4040
apply_unary_map_fn(
4141
[](const CTYPE_IN val_in) {
4242
return static_cast<CTYPE_OUT>(!static_cast<bool>(val_in));

kernels/portable/cpu/op_masked_fill.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,8 +42,8 @@ Tensor& masked_fill_scalar_out(
4242
ET_KERNEL_CHECK(
4343
ctx, tensors_have_same_dim_order(in, mask, out), InvalidArgument, out);
4444

45-
ET_SWITCH_REAL_TYPES_AND(
46-
Bool, in_type, ctx, "masked_fill.Scalar_out", CTYPE, [&]() {
45+
ET_SWITCH_REALHBBF16_TYPES(
46+
in_type, ctx, "masked_fill.Scalar_out", CTYPE, [&]() {
4747
ET_SWITCH_REAL_TYPES_AND(
4848
Bool, val_type, ctx, "masked_fill.Scalar_out", CTYPE_VAL, [&]() {
4949
CTYPE_VAL value_v;

kernels/portable/cpu/op_max_pool2d_with_indices.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ std::tuple<Tensor&, Tensor&> max_pool2d_with_indices_out(
7070
ret_val);
7171

7272
ScalarType in_type = in.scalar_type();
73-
ET_SWITCH_REAL_TYPES(
73+
ET_SWITCH_REALHBF16_TYPES(
7474
in_type, ctx, "max_pool2d_with_indices.out", CTYPE, [&]() {
7575
apply_kernel_2d_reduce_then_map_fn<CTYPE>(
7676
[](const CTYPE in_val,

kernels/portable/cpu/op_mean.cpp

Lines changed: 18 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -44,23 +44,24 @@ Tensor& mean_dim_out(
4444
InvalidArgument,
4545
out);
4646

47-
ET_SWITCH_REALHB_TYPES(in.scalar_type(), ctx, "mean.out", CTYPE_IN, [&] {
48-
ET_SWITCH_FLOATH_TYPES(out.scalar_type(), ctx, "mean.out", CTYPE_OUT, [&] {
49-
CTYPE_OUT* out_data = out.mutable_data_ptr<CTYPE_OUT>();
50-
const size_t num = get_reduced_dim_product(in, dim_list);
51-
for (size_t out_ix = 0; out_ix < out.numel(); ++out_ix) {
52-
CTYPE_OUT sum = 0;
53-
if (in.numel() > 0) {
54-
sum = map_reduce_over_dim_list<CTYPE_IN, CTYPE_OUT>(
55-
[](CTYPE_IN v) { return static_cast<CTYPE_OUT>(v); },
56-
[](CTYPE_OUT outv, CTYPE_OUT acc) { return acc + outv; },
57-
in,
58-
dim_list,
59-
out_ix);
60-
}
61-
out_data[out_ix] = sum / static_cast<float>(num);
62-
}
63-
});
47+
ET_SWITCH_REALHBBF16_TYPES(in.scalar_type(), ctx, "mean.out", CTYPE_IN, [&] {
48+
ET_SWITCH_FLOATHBF16_TYPES(
49+
out.scalar_type(), ctx, "mean.out", CTYPE_OUT, [&] {
50+
CTYPE_OUT* out_data = out.mutable_data_ptr<CTYPE_OUT>();
51+
const size_t num = get_reduced_dim_product(in, dim_list);
52+
for (size_t out_ix = 0; out_ix < out.numel(); ++out_ix) {
53+
CTYPE_OUT sum = 0;
54+
if (in.numel() > 0) {
55+
sum = map_reduce_over_dim_list<CTYPE_IN, CTYPE_OUT>(
56+
[](CTYPE_IN v) { return static_cast<CTYPE_OUT>(v); },
57+
[](CTYPE_OUT outv, CTYPE_OUT acc) { return acc + outv; },
58+
in,
59+
dim_list,
60+
out_ix);
61+
}
62+
out_data[out_ix] = sum / static_cast<float>(num);
63+
}
64+
});
6465
});
6566

6667
return out;

kernels/test/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -139,6 +139,7 @@ set(all_test_sources
139139
"op_fmod_test.cpp"
140140
"op_full_like_test.cpp"
141141
"op_full_test.cpp"
142+
"op_gather_test.cpp"
142143
"op_ge_test.cpp"
143144
"op_gelu_test.cpp"
144145
"op_glu_test.cpp"

kernels/test/op_fill_test.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -92,15 +92,15 @@ class OpFillTest : public OperatorTest {
9292
TEST_FILL_OUT(test_fill_scalar_out, DTYPE); \
9393
}
9494

95-
ET_FORALL_REAL_TYPES_AND(Bool, GENERATE_SCALAR_INPUT_SUPPORT_TEST)
95+
ET_FORALL_REALHBBF16_TYPES(GENERATE_SCALAR_INPUT_SUPPORT_TEST)
9696

9797
// Create input support tests for tensor variant.
9898
#define GENERATE_TENSOR_INPUT_SUPPORT_TEST(_, DTYPE) \
9999
TEST_F(OpFillTest, DTYPE##TensorInputSupport) { \
100100
TEST_FILL_OUT(test_fill_tensor_out, DTYPE); \
101101
}
102102

103-
ET_FORALL_REAL_TYPES_AND(Bool, GENERATE_TENSOR_INPUT_SUPPORT_TEST)
103+
ET_FORALL_REALHBBF16_TYPES(GENERATE_TENSOR_INPUT_SUPPORT_TEST)
104104

105105
TEST_F(OpFillTest, MismatchedOtherPropertiesDies) {
106106
TensorFactory<ScalarType::Int> tf;

0 commit comments

Comments
 (0)