Skip to content

Commit 72ee1ba

Browse files
swolchokYIWENX14
authored andcommitted
Support Half/BFloat16 in round (#7862)
Partial fix for #7748.
1 parent c5afb4d commit 72ee1ba

File tree

2 files changed

+14
-2
lines changed

2 files changed

+14
-2
lines changed

kernels/portable/cpu/op_round.cpp

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,14 +43,18 @@ Tensor& round_out(KernelRuntimeContext& ctx, const Tensor& in, Tensor& out) {
4343

4444
ET_KERNEL_CHECK(
4545
ctx, tensors_have_same_shape_and_dtype(in, out), InvalidArgument, out);
46-
ET_KERNEL_CHECK(ctx, tensor_is_real_type(out), InvalidArgument, out);
46+
ET_KERNEL_CHECK(
47+
ctx,
48+
executorch::runtime::tensor_is_realhbf16_type(out),
49+
InvalidArgument,
50+
out);
4751

4852
ET_KERNEL_CHECK(
4953
ctx, tensors_have_same_dim_order(in, out), InvalidArgument, out);
5054

5155
auto in_scalar_type = in.scalar_type();
5256

53-
ET_SWITCH_REAL_TYPES(in.scalar_type(), ctx, "round.out", CTYPE, [&] {
57+
ET_SWITCH_REALHBF16_TYPES(in.scalar_type(), ctx, "round.out", CTYPE, [&] {
5458
apply_unary_map_fn(
5559
[in_scalar_type](const CTYPE val_in) {
5660
if (isIntegralType(in_scalar_type, /*includeBool=*/false)) {

kernels/test/op_round_test.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,14 @@ TEST_F(OpRoundTest, DoubleTensors) {
8484
test_round_execution_floats<ScalarType::Double>();
8585
}
8686

87+
TEST_F(OpRoundTest, HalfTensors) {
88+
test_round_execution_floats<ScalarType::Half>();
89+
}
90+
91+
TEST_F(OpRoundTest, BFloat16Tensors) {
92+
test_round_execution_floats<ScalarType::BFloat16>();
93+
}
94+
8795
TEST_F(OpRoundTest, ByteTensors) {
8896
TensorFactory<ScalarType::Byte> tf;
8997

0 commit comments

Comments
 (0)