Skip to content

Commit f82be32

Browse files
authored
Support Half/BFloat16 in relu (#7858)
Partial fix for #7748.
1 parent 74d4fb6 commit f82be32

File tree

2 files changed

+14
-2
lines changed

2 files changed

+14
-2
lines changed

kernels/portable/cpu/op_relu.cpp

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,12 +33,16 @@ Tensor& relu_out(KernelRuntimeContext& ctx, const Tensor& in, Tensor& out) {
3333
ET_KERNEL_CHECK(
3434
ctx, tensors_have_same_shape_and_dtype(in, out), InvalidArgument, out);
3535

36-
ET_KERNEL_CHECK(ctx, tensor_is_real_type(out), InvalidArgument, out);
36+
ET_KERNEL_CHECK(
37+
ctx,
38+
executorch::runtime::tensor_is_realhbf16_type(out),
39+
InvalidArgument,
40+
out);
3741

3842
ET_KERNEL_CHECK(
3943
ctx, tensors_have_same_dim_order(in, out), InvalidArgument, out);
4044

41-
ET_SWITCH_REAL_TYPES(in.scalar_type(), ctx, "relu.out", CTYPE, [&]() {
45+
ET_SWITCH_REALHBF16_TYPES(in.scalar_type(), ctx, "relu.out", CTYPE, [&]() {
4246
apply_unary_map_fn(
4347
[](const CTYPE val_in) {
4448
return (std::isnan(val_in) || val_in >= CTYPE(0)) ? val_in : CTYPE(0);

kernels/test/op_relu_test.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,14 @@ TEST_F(OpReluTest, DoubleTensors) {
8282
test_relu_execution_floats<ScalarType::Double>();
8383
}
8484

85+
TEST_F(OpReluTest, HalfTensors) {
86+
test_relu_execution_floats<ScalarType::Half>();
87+
}
88+
89+
TEST_F(OpReluTest, BFloat16Tensors) {
90+
test_relu_execution_floats<ScalarType::BFloat16>();
91+
}
92+
8593
TEST_F(OpReluTest, ByteTensors) {
8694
TensorFactory<ScalarType::Byte> tf;
8795

0 commit comments

Comments
 (0)