Skip to content

Commit 7337fe0

Browse files
swolchokYIWENX14
authored andcommitted
Support Half/BFloat16 in arange (#7791)
Partial fix for #7748.
1 parent be6802f commit 7337fe0

File tree

2 files changed

+6
-6
lines changed

2 files changed

+6
-6
lines changed

kernels/portable/cpu/op_arange.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ Tensor& arange_out(KernelRuntimeContext& ctx, const Scalar& end, Tensor& out) {
3939
InvalidArgument,
4040
out);
4141

42-
ET_SWITCH_REAL_TYPES(out.scalar_type(), ctx, "arange.out", CTYPE, [&]() {
42+
ET_SWITCH_REALHBF16_TYPES(out.scalar_type(), ctx, "arange.out", CTYPE, [&]() {
4343
auto out_data = out.mutable_data_ptr<CTYPE>();
4444
for (size_t i = 0; i < size; i++) {
4545
out_data[i] = static_cast<CTYPE>(i);
@@ -88,7 +88,7 @@ Tensor& arange_start_out(
8888
InvalidArgument,
8989
out);
9090

91-
ET_SWITCH_REAL_TYPES(
91+
ET_SWITCH_REALHBF16_TYPES(
9292
out.scalar_type(), ctx, "arange.start_out", CTYPE, [&]() {
9393
auto out_data = out.mutable_data_ptr<CTYPE>();
9494
for (size_t i = 0; i < size; i++) {

kernels/test/op_arange_test.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -86,9 +86,9 @@ class OpArangeStartOutTest : public OperatorTest {
8686
};
8787

8888
/// A generic smoke test that works for any dtype that supports zeros().
89-
TEST_F(OpArangeOutTest, AllRealDtypesSupported) {
89+
TEST_F(OpArangeOutTest, AllRealHBF16DtypesSupported) {
9090
#define TEST_ENTRY(ctype, dtype) test_arange_dtype<ctype, ScalarType::dtype>();
91-
ET_FORALL_REAL_TYPES(TEST_ENTRY);
91+
ET_FORALL_REALHBF16_TYPES(TEST_ENTRY);
9292
#undef TEST_ENTRY
9393
}
9494

@@ -164,10 +164,10 @@ TEST_F(OpArangeOutTest, DynamicShapeUnbound) {
164164
}
165165

166166
/// A generic smoke test that works for any dtype that supports zeros().
167-
TEST_F(OpArangeStartOutTest, AllRealDtypesSupported) {
167+
TEST_F(OpArangeStartOutTest, AllRealHBF16DtypesSupported) {
168168
#define TEST_ENTRY(ctype, dtype) \
169169
test_arange_start_dtype<ctype, ScalarType::dtype>();
170-
ET_FORALL_REAL_TYPES(TEST_ENTRY);
170+
ET_FORALL_REALHBF16_TYPES(TEST_ENTRY);
171171
#undef TEST_ENTRY
172172
}
173173

0 commit comments

Comments
 (0)