1 parent 466d98f commit c8fd7f1
kernels/portable/cpu/op_leaky_relu.cpp
@@ -44,7 +44,7 @@ Tensor& leaky_relu_out(
 
   ET_KERNEL_CHECK(ctx, in_type == out_type, InvalidArgument, out);
 
-  ET_SWITCH_FLOAT_TYPES(in_type, ctx, "leaky_relu.out", CTYPE, [&]() {
+  ET_SWITCH_FLOATHBF16_TYPES(in_type, ctx, "leaky_relu.out", CTYPE, [&]() {
     CTYPE negative_slope_casted;
     ET_SWITCH_SCALAR_OBJ_TYPES(
         sc_type, ctx, "leaky_relu.out", CTYPE_MIN, [&]() {
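
The kernel change above widens the dtype dispatch in the portable LeakyReLU kernel from float only to float, half, and bfloat16 by swapping ET_SWITCH_FLOAT_TYPES for ET_SWITCH_FLOATHBF16_TYPES. For readers unfamiliar with these macros, here is a minimal standalone sketch of the dispatch pattern they implement, with made-up stand-ins for ScalarType, Half, and BFloat16; ExecuTorch's real macros (in scalar_type_util.h) differ in detail and, as the hunk shows, also take the runtime context and an operator name for error reporting.

// Illustrative sketch only, not ExecuTorch's actual macro: map a runtime
// ScalarType to a compile-time C type and run the caller's lambda with that
// type alias in scope.
#include <cstdint>
#include <iostream>

enum class ScalarType { Float, Half, BFloat16 };

struct Half { uint16_t bits; };      // placeholder 16-bit types for the sketch
struct BFloat16 { uint16_t bits; };

// The lambda text is pasted into each case, after CTYPE_ALIAS is bound to the
// C type matching the runtime ScalarType.
#define SKETCH_SWITCH_FLOATHBF16_TYPES(TYPE, CTYPE_ALIAS, ...) \
  switch (TYPE) {                                              \
    case ScalarType::Float: {                                  \
      using CTYPE_ALIAS = float;                               \
      (__VA_ARGS__)();                                         \
      break;                                                   \
    }                                                          \
    case ScalarType::Half: {                                   \
      using CTYPE_ALIAS = Half;                                \
      (__VA_ARGS__)();                                         \
      break;                                                   \
    }                                                          \
    case ScalarType::BFloat16: {                               \
      using CTYPE_ALIAS = BFloat16;                            \
      (__VA_ARGS__)();                                         \
      break;                                                   \
    }                                                          \
    default:                                                   \
      /* real kernels report InvalidArgument via ET_KERNEL_CHECK */ \
      break;                                                   \
  }

int main() {
  ScalarType in_type = ScalarType::Half;
  SKETCH_SWITCH_FLOATHBF16_TYPES(in_type, CTYPE, [&]() {
    // In the kernel this body casts negative_slope to CTYPE and applies
    // LeakyReLU elementwise; here we only report the selected type's size.
    std::cout << "dispatched to a " << sizeof(CTYPE) << "-byte type\n";
  });
  return 0;
}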
kernels/test/op_leaky_relu_test.cpp
@@ -29,15 +29,21 @@ class OpLeakyReluTest : public OperatorTest {
     return torch::executor::aten::leaky_relu_outf(
         context_, in, negative_slope, out);
   }
-};
+  template <ScalarType DTYPE>
+  void test_leaky_relu_dtype() {
+    TensorFactory<DTYPE> tf;
+    Tensor in = tf.ones({2, 2});
+    Tensor out = tf.zeros({2, 2});
 
-TEST_F(OpLeakyReluTest, SanityCheck) {
-  TensorFactory<ScalarType::Float> tf;
-  Tensor in = tf.ones({2, 2});
-  Tensor out = tf.zeros({2, 2});
+    Tensor ret = op_leaky_relu_out(in, -0.01, out);
 
-  Tensor ret = op_leaky_relu_out(in, -0.01, out);
+    EXPECT_TENSOR_EQ(out, ret);
+    EXPECT_TENSOR_EQ(out, tf.ones({2, 2}));
+  }
+};
 
-  EXPECT_TENSOR_EQ(out, ret);
-  EXPECT_TENSOR_EQ(out, tf.ones({2, 2}));
+TEST_F(OpLeakyReluTest, SanityCheck) {
+#define TEST_ENTRY(ctype, dtype) test_leaky_relu_dtype<ScalarType::dtype>();
+  ET_FORALL_FLOATHBF16_TYPES(TEST_ENTRY);
+#undef TEST_ENTRY
 }
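
On the test side, the float-only SanityCheck body becomes a templated helper, test_leaky_relu_dtype<DTYPE>(), which the test then instantiates once per supported floating-point dtype via the X-macro ET_FORALL_FLOATHBF16_TYPES. The sketch below shows how that expansion pattern works; it uses a simplified ScalarType and a hypothetical stand-in for the FORALL list, not ExecuTorch's real macro.

// Sketch of the X-macro pattern: the FORALL macro applies the caller-supplied
// TEST_ENTRY macro to each (ctype, dtype) pair, producing one helper call per
// dtype. Names prefixed SKETCH_ are stand-ins, not ExecuTorch APIs.
#include <iostream>

enum class ScalarType { Float, Half, BFloat16 };

template <ScalarType DTYPE>
void test_leaky_relu_dtype() {
  // The real helper builds 2x2 tensors with TensorFactory<DTYPE>, runs
  // op_leaky_relu_out, and checks the result with EXPECT_TENSOR_EQ.
  std::cout << "would run the LeakyReLU sanity check for dtype "
            << static_cast<int>(DTYPE) << "\n";
}

#define SKETCH_FORALL_FLOATHBF16_TYPES(_) \
  _(float, Float)                         \
  _(Half, Half)                           \
  _(BFloat16, BFloat16)

int main() {
  // Mirrors the rewritten SanityCheck body: one instantiation per dtype.
#define TEST_ENTRY(ctype, dtype) test_leaky_relu_dtype<ScalarType::dtype>();
  SKETCH_FORALL_FLOATHBF16_TYPES(TEST_ENTRY);
#undef TEST_ENTRY
  return 0;
}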