Skip to content

Commit ed01507

Browse files
committed
Update
[ghstack-poisoned]
1 parent 466d98f commit ed01507

File tree

2 files changed

+19
-12
lines changed

2 files changed

+19
-12
lines changed

kernels/portable/cpu/op_flip.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ Tensor& flip_out(
6666

6767
constexpr auto name = "flip.out";
6868

69-
ET_SWITCH_REALHB_TYPES(in.scalar_type(), ctx, name, CTYPE, [&] {
69+
ET_SWITCH_REALHBBF16_TYPES(in.scalar_type(), ctx, name, CTYPE, [&] {
7070
const CTYPE* in_data = in.const_data_ptr<CTYPE>();
7171
CTYPE* out_data = out.mutable_data_ptr<CTYPE>();
7272

kernels/test/op_flip_test.cpp

Lines changed: 18 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -33,20 +33,27 @@ class OpFlipOutTest : public ::testing::Test {
3333
// first.
3434
torch::executor::runtime_init();
3535
}
36+
37+
template <ScalarType DTYPE>
38+
void test_1d_dtype() {
39+
TensorFactory<DTYPE> tf;
40+
41+
Tensor input =
42+
tf.make({4, 1, 3}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
43+
int64_t dims_data[1] = {-1};
44+
IntArrayRef dims = IntArrayRef(dims_data, 1);
45+
Tensor out = tf.zeros({4, 1, 3});
46+
Tensor out_expected =
47+
tf.make({4, 1, 3}, {3, 2, 1, 6, 5, 4, 9, 8, 7, 12, 11, 10});
48+
op_flip_out(input, dims, out);
49+
EXPECT_TENSOR_CLOSE(out, out_expected);
50+
}
3651
};
3752

3853
TEST_F(OpFlipOutTest, SmokeTest1Dim) {
39-
TensorFactory<ScalarType::Float> tfFloat;
40-
41-
Tensor input =
42-
tfFloat.make({4, 1, 3}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
43-
int64_t dims_data[1] = {-1};
44-
IntArrayRef dims = IntArrayRef(dims_data, 1);
45-
Tensor out = tfFloat.zeros({4, 1, 3});
46-
Tensor out_expected =
47-
tfFloat.make({4, 1, 3}, {3, 2, 1, 6, 5, 4, 9, 8, 7, 12, 11, 10});
48-
op_flip_out(input, dims, out);
49-
EXPECT_TENSOR_CLOSE(out, out_expected);
54+
#define TEST_ENTRY(ctype, dtype) test_1d_dtype<ScalarType::dtype>();
55+
ET_FORALL_REALHBF16_TYPES(TEST_ENTRY);
56+
#undef TEST_ENTRY
5057
}
5158

5259
TEST_F(OpFlipOutTest, SmokeTest2Dims) {

0 commit comments

Comments
 (0)