From 7379949c63d67b1bb41fee157f31799f60e3c4ca Mon Sep 17 00:00:00 2001 From: Scott Wolchok Date: Tue, 21 Jan 2025 12:55:59 -0800 Subject: [PATCH] Update [ghstack-poisoned] --- kernels/portable/cpu/op_constant_pad_nd.cpp | 21 +++++------ kernels/test/op_constant_pad_nd_test.cpp | 42 ++++++++++----------- 2 files changed, 31 insertions(+), 32 deletions(-) diff --git a/kernels/portable/cpu/op_constant_pad_nd.cpp b/kernels/portable/cpu/op_constant_pad_nd.cpp index 28b0c4b034b..328207d70f3 100644 --- a/kernels/portable/cpu/op_constant_pad_nd.cpp +++ b/kernels/portable/cpu/op_constant_pad_nd.cpp @@ -184,17 +184,16 @@ Tensor& constant_pad_nd_out( ScalarType in_type = in.scalar_type(); ScalarType value_type = utils::get_scalar_dtype(value); - ET_SWITCH_REAL_TYPES_AND( - Bool, in_type, ctx, "constant_pad_nd.out", CTYPE, [&]() { - CTYPE value_v; - ET_SWITCH_SCALAR_OBJ_TYPES( - value_type, ctx, "constant_pad_nd.out", CTYPE_VALUE, [&]() { - CTYPE_VALUE val; - utils::extract_scalar(value, &val); - value_v = static_cast(val); - }); - constant_pad_nd_out_impl(in, pad, value_v, out); - }); + ET_SWITCH_REALHBBF16_TYPES(in_type, ctx, "constant_pad_nd.out", CTYPE, [&]() { + CTYPE value_v; + ET_SWITCH_SCALAR_OBJ_TYPES( + value_type, ctx, "constant_pad_nd.out", CTYPE_VALUE, [&]() { + CTYPE_VALUE val; + utils::extract_scalar(value, &val); + value_v = static_cast(val); + }); + constant_pad_nd_out_impl(in, pad, value_v, out); + }); return out; } diff --git a/kernels/test/op_constant_pad_nd_test.cpp b/kernels/test/op_constant_pad_nd_test.cpp index 5ddc310c895..8d5befb108d 100644 --- a/kernels/test/op_constant_pad_nd_test.cpp +++ b/kernels/test/op_constant_pad_nd_test.cpp @@ -50,7 +50,7 @@ class OpConstantPadNDOutTest : public OperatorTest { 5, 6, 7, 8, 1, 2, 3, 4, 5, 6, 7, 8, - + 1, 2, 3, 4, 5, 6, 7, 8, 1, 2, 3, 4, @@ -66,7 +66,7 @@ class OpConstantPadNDOutTest : public OperatorTest { 7, 5, 6, 7, 8, 7, 7, 1, 2, 3, 4, 7, 7, 5, 6, 7, 8, 7, - + 7, 1, 2, 3, 4, 7, 7, 5, 6, 7, 8, 7, 7, 1, 2, 3, 4, 7, @@ -98,7 +98,7 @@ class OpConstantPadNDOutTest : public OperatorTest { 5, 6, 7, 8, 1, 2, 3, 4, 5, 6, 7, 8, - + 1, 2, 3, 4, 5, 6, 7, 8, 1, 2, 3, 4, @@ -116,7 +116,7 @@ class OpConstantPadNDOutTest : public OperatorTest { 5, 6, 7, 8, 1, 2, 3, 4, 5, 6, 7, 8, - + 7, 7, 7, 7, 7, 7, 7, 7, 1, 2, 3, 4, @@ -150,7 +150,7 @@ class OpConstantPadNDOutTest : public OperatorTest { 5, 6, 7, 8, 1, 2, 3, 4, 5, 6, 7, 8, - + 1, 2, 3, 4, 5, 6, 7, 8, 1, 2, 3, 4, @@ -166,12 +166,12 @@ class OpConstantPadNDOutTest : public OperatorTest { 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, - + 1, 2, 3, 4, 5, 6, 7, 8, 1, 2, 3, 4, 5, 6, 7, 8, - + 1, 2, 3, 4, 5, 6, 7, 8, 1, 2, 3, 4, @@ -203,7 +203,7 @@ class OpConstantPadNDOutTest : public OperatorTest { 5, 6, 7, 8, 1, 2, 3, 4, 5, 6, 7, 8, - + 1, 2, 3, 4, 5, 6, 7, 8, 1, 2, 3, 4, @@ -221,7 +221,7 @@ class OpConstantPadNDOutTest : public OperatorTest { 7, 7, 5, 6, 7, 8, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, - + 7, 7, 1, 2, 3, 4, 7, 7, 7, 5, 6, 7, 8, 7, 7, 7, 1, 2, 3, 4, 7, @@ -255,7 +255,7 @@ class OpConstantPadNDOutTest : public OperatorTest { 5, 6, 7, 8, 1, 2, 3, 4, 5, 6, 7, 8, - + 1, 2, 3, 4, 5, 6, 7, 8, 1, 2, 3, 4, @@ -271,12 +271,12 @@ class OpConstantPadNDOutTest : public OperatorTest { 7, 7, 5, 6, 7, 8, 7, 7, 7, 1, 2, 3, 4, 7, 7, 7, 5, 6, 7, 8, 7, - + 7, 7, 1, 2, 3, 4, 7, 7, 7, 5, 6, 7, 8, 7, 7, 7, 1, 2, 3, 4, 7, 7, 7, 5, 6, 7, 8, 7, - + 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, @@ -308,7 +308,7 @@ class OpConstantPadNDOutTest : public OperatorTest { 5, 6, 7, 8, 1, 2, 3, 4, 5, 6, 7, 8, - + 1, 2, 3, 4, 5, 6, 7, 8, 1, 2, 3, 4, @@ -325,13 +325,13 @@ class OpConstantPadNDOutTest : public OperatorTest { 7, 7, 5, 6, 7, 8, 7, 7, 7, 1, 2, 3, 4, 7, 7, 7, 5, 6, 7, 8, 7, - + 7, 7, 7, 7, 7, 7, 7, 7, 7, 1, 2, 3, 4, 7, 7, 7, 5, 6, 7, 8, 7, 7, 7, 1, 2, 3, 4, 7, 7, 7, 5, 6, 7, 8, 7, - + 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, @@ -353,7 +353,7 @@ TEST_F(OpConstantPadNDOutTest, TestPadDim2) { #define TEST_ENTRY(ctype, dtype) \ test_constant_pad_nd_out_dim2(); - ET_FORALL_REAL_TYPES(TEST_ENTRY); + ET_FORALL_REALHBF16_TYPES(TEST_ENTRY); #undef TEST_ENTRY } @@ -361,7 +361,7 @@ TEST_F(OpConstantPadNDOutTest, TestPadDim1) { #define TEST_ENTRY(ctype, dtype) \ test_constant_pad_nd_out_dim1(); - ET_FORALL_REAL_TYPES(TEST_ENTRY); + ET_FORALL_REALHBF16_TYPES(TEST_ENTRY); #undef TEST_ENTRY } @@ -369,7 +369,7 @@ TEST_F(OpConstantPadNDOutTest, TestPadDim0) { #define TEST_ENTRY(ctype, dtype) \ test_constant_pad_nd_out_dim0(); - ET_FORALL_REAL_TYPES(TEST_ENTRY); + ET_FORALL_REALHBF16_TYPES(TEST_ENTRY); #undef TEST_ENTRY } @@ -377,7 +377,7 @@ TEST_F(OpConstantPadNDOutTest, TestPadDim1And2) { #define TEST_ENTRY(ctype, dtype) \ test_constant_pad_nd_out_dim12(); - ET_FORALL_REAL_TYPES(TEST_ENTRY); + ET_FORALL_REALHBF16_TYPES(TEST_ENTRY); #undef TEST_ENTRY } @@ -385,7 +385,7 @@ TEST_F(OpConstantPadNDOutTest, TestPadDim0And2) { #define TEST_ENTRY(ctype, dtype) \ test_constant_pad_nd_out_dim02(); - ET_FORALL_REAL_TYPES(TEST_ENTRY); + ET_FORALL_REALHBF16_TYPES(TEST_ENTRY); #undef TEST_ENTRY } @@ -393,7 +393,7 @@ TEST_F(OpConstantPadNDOutTest, TestPadDim0And1And2) { #define TEST_ENTRY(ctype, dtype) \ test_constant_pad_nd_out_dim012(); - ET_FORALL_REAL_TYPES(TEST_ENTRY); + ET_FORALL_REALHBF16_TYPES(TEST_ENTRY); #undef TEST_ENTRY }