Skip to content

Commit f5ba5bc

Browse files
authored
Support Half/BFloat16 in cdist (#7800)
Partial fix for #7748.
1 parent 091bc4a commit f5ba5bc

File tree

3 files changed: 97 additions and 78 deletions.

kernels/portable/cpu/op_cdist_forward.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -162,7 +162,7 @@ Tensor& _cdist_forward_out(
162162
ScalarType out_type = out.scalar_type();
163163
constexpr auto name = "_cdist_forward.out";
164164

165-
ET_SWITCH_FLOAT_TYPES(
165+
ET_SWITCH_FLOATHBF16_TYPES(
166166
out_type, ctx, name, CTYPE, [&] { cdist<CTYPE>(x1, x2, out, p); });
167167

168168
return out;

kernels/test/op_cdist_forward_test.cpp

Lines changed: 93 additions & 77 deletions
Original file line numberDiff line numberDiff line change
@@ -40,89 +40,105 @@ class OpCdistForwardOutTest : public ::testing::Test {
4040
// first.
4141
torch::executor::runtime_init();
4242
}
43-
};
4443

45-
TEST_F(OpCdistForwardOutTest, SmokeTest) {
46-
TensorFactory<ScalarType::Float> tfFloat;
44+
template <ScalarType DTYPE>
45+
void test_dtype() {
46+
TensorFactory<DTYPE> tf;
47+
48+
Tensor x1 = tf.make({2, 1, 4, 3}, {0, 1, 2, 3, 5, 4, 3, -3, 7, 1, 6, 2,
49+
-1, 5, 1, 1, -2, 1, 5, 4, 3, 2, -1, 5});
50+
Tensor x2 = tf.make(
51+
{1, 2, 5, 3}, {0, 1, 2, 3, 5, -3, 7, 1, 6, 2, -1, 5, 1, 1, -2,
52+
4, 3, 2, -1, 5, 1, 1, -2, 1, 5, 4, 3, 2, -1, 5});
53+
optional<int64_t> compute_mode = optional<int64_t>();
4754

48-
Tensor x1 =
49-
tfFloat.make({2, 1, 4, 3}, {0, 1, 2, 3, 5, 4, 3, -3, 7, 1, 6, 2,
50-
-1, 5, 1, 1, -2, 1, 5, 4, 3, 2, -1, 5});
51-
Tensor x2 = tfFloat.make(
52-
{1, 2, 5, 3}, {0, 1, 2, 3, 5, -3, 7, 1, 6, 2, -1, 5, 1, 1, -2,
53-
4, 3, 2, -1, 5, 1, 1, -2, 1, 5, 4, 3, 2, -1, 5});
54-
optional<int64_t> compute_mode = optional<int64_t>();
55+
Tensor out = tf.zeros({2, 2, 4, 5});
5556

56-
Tensor out = tfFloat.zeros({2, 2, 4, 5});
57+
Tensor l0 = tf.make(
58+
{2, 2, 4, 5},
59+
{0., 3., 2., 3., 2., 3., 1., 3., 3., 3., 3., 2., 3., 3., 3., 2.,
60+
3., 3., 3., 2., 2., 3., 3., 3., 3., 3., 2., 3., 3., 3., 3., 3.,
61+
3., 3., 3., 2., 3., 2., 3., 3., 3., 2., 3., 3., 3., 3., 3., 3.,
62+
3., 2., 3., 3., 3., 3., 3., 3., 3., 3., 0., 3., 3., 0., 2., 3.,
63+
3., 3., 2., 0., 3., 3., 3., 3., 3., 0., 3., 3., 3., 3., 3., 0.});
64+
op_cdist_forward_out(x1, x2, 0.0, compute_mode, out);
65+
EXPECT_TENSOR_CLOSE(out, l0);
5766

58-
Tensor l0 = tfFloat.make(
59-
{2, 2, 4, 5},
60-
{0., 3., 2., 3., 2., 3., 1., 3., 3., 3., 3., 2., 3., 3., 3., 2.,
61-
3., 3., 3., 2., 2., 3., 3., 3., 3., 3., 2., 3., 3., 3., 3., 3.,
62-
3., 3., 3., 2., 3., 2., 3., 3., 3., 2., 3., 3., 3., 3., 3., 3.,
63-
3., 2., 3., 3., 3., 3., 3., 3., 3., 3., 0., 3., 3., 0., 2., 3.,
64-
3., 3., 2., 0., 3., 3., 3., 3., 3., 0., 3., 3., 3., 3., 3., 0.});
65-
op_cdist_forward_out(x1, x2, 0.0, compute_mode, out);
66-
EXPECT_TENSOR_CLOSE(out, l0);
67+
Tensor l1 = tf.make(
68+
{2, 2, 4, 5},
69+
{0., 12., 11., 7., 5., 9., 7., 10., 8., 12., 12., 18., 9., 5.,
70+
15., 6., 8., 15., 11., 9., 6., 6., 5., 9., 7., 5., 7., 12.,
71+
4., 8., 12., 18., 9., 13., 5., 6., 4., 9., 7., 11., 6., 8.,
72+
17., 13., 9., 5., 13., 14., 6., 6., 9., 9., 8., 10., 12., 7.,
73+
15., 8., 0., 10., 8., 0., 9., 9., 13., 9., 9., 0., 12., 6.,
74+
3., 9., 12., 0., 10., 9., 13., 6., 10., 0.});
75+
op_cdist_forward_out(x1, x2, 1.0, compute_mode, out);
76+
EXPECT_TENSOR_CLOSE(out, l1);
6777

68-
Tensor l1 = tfFloat.make(
69-
{2, 2, 4, 5},
70-
{0., 12., 11., 7., 5., 9., 7., 10., 8., 12., 12., 18., 9., 5.,
71-
15., 6., 8., 15., 11., 9., 6., 6., 5., 9., 7., 5., 7., 12.,
72-
4., 8., 12., 18., 9., 13., 5., 6., 4., 9., 7., 11., 6., 8.,
73-
17., 13., 9., 5., 13., 14., 6., 6., 9., 9., 8., 10., 12., 7.,
74-
15., 8., 0., 10., 8., 0., 9., 9., 13., 9., 9., 0., 12., 6.,
75-
3., 9., 12., 0., 10., 9., 13., 6., 10., 0.});
76-
op_cdist_forward_out(x1, x2, 1.0, compute_mode, out);
77-
EXPECT_TENSOR_CLOSE(out, l1);
78+
Tensor l2 = tf.make(
79+
{2, 2, 4, 5},
80+
{0.00000000, 7.07106781, 8.06225777, 4.12310553, 4.12310553,
81+
5.38516474, 7.00000000, 6.00000000, 6.16441393, 7.48331499,
82+
7.07106781, 12.80624866, 5.74456263, 3.00000000, 10.04987526,
83+
5.09901953, 5.47722578, 8.77496433, 7.68114567, 6.40312433,
84+
4.47213602, 4.24264050, 3.31662488, 5.91608000, 4.12310553,
85+
3.00000000, 5.00000000, 7.87400770, 2.44948983, 6.16441393,
86+
7.87400770, 10.77032948, 6.40312433, 8.30662346, 3.00000000,
87+
4.24264050, 2.44948983, 8.06225777, 4.58257580, 7.68114567,
88+
4.24264050, 5.65685415, 10.24695110, 7.81024981, 5.38516474,
89+
3.31662488, 8.30662346, 8.36660004, 4.24264050, 4.24264050,
90+
5.91608000, 6.40312433, 4.69041586, 6.16441393, 7.07106781,
91+
4.12310553, 10.04987526, 5.47722578, 0.00000000, 7.34846926,
92+
5.47722578, 0.00000000, 7.28010988, 6.40312433, 7.81024981,
93+
5.91608000, 7.28010988, 0.00000000, 7.48331499, 4.24264050,
94+
1.73205078, 6.40312433, 7.48331499, 0.00000000, 6.16441393,
95+
5.38516474, 7.81024981, 4.24264050, 6.16441393, 0.00000000});
96+
op_cdist_forward_out(x1, x2, 2.0, compute_mode, out);
97+
EXPECT_TENSOR_CLOSE(out, l2);
7898

79-
Tensor l2 = tfFloat.make(
80-
{2, 2, 4, 5},
81-
{0.00000000, 7.07106781, 8.06225777, 4.12310553, 4.12310553,
82-
5.38516474, 7.00000000, 6.00000000, 6.16441393, 7.48331499,
83-
7.07106781, 12.80624866, 5.74456263, 3.00000000, 10.04987526,
84-
5.09901953, 5.47722578, 8.77496433, 7.68114567, 6.40312433,
85-
4.47213602, 4.24264050, 3.31662488, 5.91608000, 4.12310553,
86-
3.00000000, 5.00000000, 7.87400770, 2.44948983, 6.16441393,
87-
7.87400770, 10.77032948, 6.40312433, 8.30662346, 3.00000000,
88-
4.24264050, 2.44948983, 8.06225777, 4.58257580, 7.68114567,
89-
4.24264050, 5.65685415, 10.24695110, 7.81024981, 5.38516474,
90-
3.31662488, 8.30662346, 8.36660004, 4.24264050, 4.24264050,
91-
5.91608000, 6.40312433, 4.69041586, 6.16441393, 7.07106781,
92-
4.12310553, 10.04987526, 5.47722578, 0.00000000, 7.34846926,
93-
5.47722578, 0.00000000, 7.28010988, 6.40312433, 7.81024981,
94-
5.91608000, 7.28010988, 0.00000000, 7.48331499, 4.24264050,
95-
1.73205078, 6.40312433, 7.48331499, 0.00000000, 6.16441393,
96-
5.38516474, 7.81024981, 4.24264050, 6.16441393, 0.00000000});
97-
op_cdist_forward_out(x1, x2, 2.0, compute_mode, out);
98-
EXPECT_TENSOR_CLOSE(out, l2);
99+
Tensor l3 = tf.make(
100+
{2, 2, 4, 5},
101+
{0.00000000, 6.00000000, 7.41079521, 3.50339794, 4.02072573,
102+
4.62606478, 7.00000000, 5.14256334, 6.01846170, 6.60385466,
103+
6.00000000, 11.47758675, 5.05277443, 2.57128167, 9.28704357,
104+
5.01329803, 5.11722994, 7.39863634, 7.18551636, 5.73879337,
105+
4.16016769, 4.04124022, 3.07231688, 5.34848118, 3.50339794,
106+
2.57128167, 4.49794149, 7.23042679, 2.15443468, 6.01846170,
107+
6.99319077, 9.25212955, 6.08220196, 7.45903587, 2.57128167,
108+
3.77976322, 2.15443468, 8.00520515, 4.17933941, 7.18551636,
109+
4.04124022, 5.03968430, 8.88326645, 6.74599648, 4.62606478,
110+
3.07231688, 7.45903587, 7.16609573, 4.04124022, 3.77976322,
111+
5.34848118, 6.08220196, 3.95789170, 5.42883539, 6.00000000,
112+
3.50339794, 9.00000000, 5.11722994, 0.00000000, 7.06069660,
113+
5.11722994, 0.00000000, 7.05400419, 6.08220196, 6.74599648,
114+
5.34848118, 7.05400419, 0.00000000, 6.60385466, 4.04124022,
115+
1.44224954, 6.08220196, 6.60385466, 0.00000000, 5.42883539,
116+
4.62606478, 6.74599648, 4.04124022, 5.42883539, 0.00000000});
117+
op_cdist_forward_out(x1, x2, 3.0, compute_mode, out);
118+
if (DTYPE == ScalarType::BFloat16) {
119+
EXPECT_TENSOR_CLOSE_WITH_TOL(
120+
out,
121+
l3,
122+
1e-2,
123+
executorch::runtime::testing::internal::kDefaultBFloat16Atol);
124+
} else {
125+
EXPECT_TENSOR_CLOSE(out, l3);
126+
}
99127

100-
Tensor l3 = tfFloat.make(
101-
{2, 2, 4, 5},
102-
{0.00000000, 6.00000000, 7.41079521, 3.50339794, 4.02072573, 4.62606478,
103-
7.00000000, 5.14256334, 6.01846170, 6.60385466, 6.00000000, 11.47758675,
104-
5.05277443, 2.57128167, 9.28704357, 5.01329803, 5.11722994, 7.39863634,
105-
7.18551636, 5.73879337, 4.16016769, 4.04124022, 3.07231688, 5.34848118,
106-
3.50339794, 2.57128167, 4.49794149, 7.23042679, 2.15443468, 6.01846170,
107-
6.99319077, 9.25212955, 6.08220196, 7.45903587, 2.57128167, 3.77976322,
108-
2.15443468, 8.00520515, 4.17933941, 7.18551636, 4.04124022, 5.03968430,
109-
8.88326645, 6.74599648, 4.62606478, 3.07231688, 7.45903587, 7.16609573,
110-
4.04124022, 3.77976322, 5.34848118, 6.08220196, 3.95789170, 5.42883539,
111-
6.00000000, 3.50339794, 9.00000000, 5.11722994, 0.00000000, 7.06069660,
112-
5.11722994, 0.00000000, 7.05400419, 6.08220196, 6.74599648, 5.34848118,
113-
7.05400419, 0.00000000, 6.60385466, 4.04124022, 1.44224954, 6.08220196,
114-
6.60385466, 0.00000000, 5.42883539, 4.62606478, 6.74599648, 4.04124022,
115-
5.42883539, 0.00000000});
116-
op_cdist_forward_out(x1, x2, 3.0, compute_mode, out);
117-
EXPECT_TENSOR_CLOSE(out, l3);
128+
Tensor linf = tf.make(
129+
{2, 2, 4, 5},
130+
{0., 5., 7., 3., 4., 4., 7., 4., 6., 6., 5., 10., 4., 2., 9., 5.,
131+
5., 6., 7., 5., 4., 4., 3., 5., 3., 2., 4., 7., 2., 6., 6., 8.,
132+
6., 7., 2., 3., 2., 8., 4., 7., 4., 4., 8., 6., 4., 3., 7., 6.,
133+
4., 3., 5., 6., 3., 5., 5., 3., 8., 5., 0., 7., 5., 0., 7., 6.,
134+
6., 5., 7., 0., 6., 4., 1., 6., 6., 0., 5., 4., 6., 4., 5., 0.});
135+
op_cdist_forward_out(x1, x2, INFINITY, compute_mode, out);
136+
EXPECT_TENSOR_CLOSE(out, linf);
137+
}
138+
};
118139

119-
Tensor linf = tfFloat.make(
120-
{2, 2, 4, 5},
121-
{0., 5., 7., 3., 4., 4., 7., 4., 6., 6., 5., 10., 4., 2., 9., 5.,
122-
5., 6., 7., 5., 4., 4., 3., 5., 3., 2., 4., 7., 2., 6., 6., 8.,
123-
6., 7., 2., 3., 2., 8., 4., 7., 4., 4., 8., 6., 4., 3., 7., 6.,
124-
4., 3., 5., 6., 3., 5., 5., 3., 8., 5., 0., 7., 5., 0., 7., 6.,
125-
6., 5., 7., 0., 6., 4., 1., 6., 6., 0., 5., 4., 6., 4., 5., 0.});
126-
op_cdist_forward_out(x1, x2, INFINITY, compute_mode, out);
127-
EXPECT_TENSOR_CLOSE(out, linf);
140+
TEST_F(OpCdistForwardOutTest, SmokeTest) {
141+
#define TEST_ENTRY(ctype, dtype) test_dtype<ScalarType::dtype>();
142+
ET_FORALL_FLOATHBF16_TYPES(TEST_ENTRY);
143+
#undef TEST_ENTRY
128144
}

runtime/core/exec_aten/testing_util/tensor_util.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,9 @@ double default_atol_for_type(ScalarType t) {
8080
if (t == ScalarType::Half) {
8181
return internal::kDefaultHalfAtol;
8282
}
83+
if (t == ScalarType::BFloat16) {
84+
return internal::kDefaultBFloat16Atol;
85+
}
8386
return internal::kDefaultAtol;
8487
}
8588
} // namespace

0 commit comments

Comments
 (0)