From c8fd7f1066f57aa1f3b52703ff9260047228a92e Mon Sep 17 00:00:00 2001 From: Scott Wolchok Date: Tue, 21 Jan 2025 16:06:21 -0800 Subject: [PATCH] Update [ghstack-poisoned] --- kernels/portable/cpu/op_leaky_relu.cpp | 2 +- kernels/test/op_leaky_relu_test.cpp | 22 ++++++++++++++-------- 2 files changed, 15 insertions(+), 9 deletions(-) diff --git a/kernels/portable/cpu/op_leaky_relu.cpp b/kernels/portable/cpu/op_leaky_relu.cpp index 90e91435e4c..3493c26e477 100644 --- a/kernels/portable/cpu/op_leaky_relu.cpp +++ b/kernels/portable/cpu/op_leaky_relu.cpp @@ -44,7 +44,7 @@ Tensor& leaky_relu_out( ET_KERNEL_CHECK(ctx, in_type == out_type, InvalidArgument, out); - ET_SWITCH_FLOAT_TYPES(in_type, ctx, "leaky_relu.out", CTYPE, [&]() { + ET_SWITCH_FLOATHBF16_TYPES(in_type, ctx, "leaky_relu.out", CTYPE, [&]() { CTYPE negative_slope_casted; ET_SWITCH_SCALAR_OBJ_TYPES( sc_type, ctx, "leaky_relu.out", CTYPE_MIN, [&]() { diff --git a/kernels/test/op_leaky_relu_test.cpp b/kernels/test/op_leaky_relu_test.cpp index 1c5ca68152a..514c7dc6b51 100644 --- a/kernels/test/op_leaky_relu_test.cpp +++ b/kernels/test/op_leaky_relu_test.cpp @@ -29,15 +29,21 @@ class OpLeakyReluTest : public OperatorTest { return torch::executor::aten::leaky_relu_outf( context_, in, negative_slope, out); } -}; + template + void test_leaky_relu_dtype() { + TensorFactory tf; + Tensor in = tf.ones({2, 2}); + Tensor out = tf.zeros({2, 2}); -TEST_F(OpLeakyReluTest, SanityCheck) { - TensorFactory tf; - Tensor in = tf.ones({2, 2}); - Tensor out = tf.zeros({2, 2}); + Tensor ret = op_leaky_relu_out(in, -0.01, out); - Tensor ret = op_leaky_relu_out(in, -0.01, out); + EXPECT_TENSOR_EQ(out, ret); + EXPECT_TENSOR_EQ(out, tf.ones({2, 2})); + } +}; - EXPECT_TENSOR_EQ(out, ret); - EXPECT_TENSOR_EQ(out, tf.ones({2, 2})); +TEST_F(OpLeakyReluTest, SanityCheck) { +#define TEST_ENTRY(ctype, dtype) test_leaky_relu_dtype(); + ET_FORALL_FLOATHBF16_TYPES(TEST_ENTRY); +#undef TEST_ENTRY }