@@ -40,89 +40,105 @@ class OpCdistForwardOutTest : public ::testing::Test {
4040 // first.
4141 torch::executor::runtime_init ();
4242 }
43- };
4443
45- TEST_F (OpCdistForwardOutTest, SmokeTest) {
46- TensorFactory<ScalarType::Float> tfFloat;
44+ template <ScalarType DTYPE>
45+ void test_dtype () {
46+ TensorFactory<DTYPE> tf;
47+
48+ Tensor x1 = tf.make ({2 , 1 , 4 , 3 }, {0 , 1 , 2 , 3 , 5 , 4 , 3 , -3 , 7 , 1 , 6 , 2 ,
49+ -1 , 5 , 1 , 1 , -2 , 1 , 5 , 4 , 3 , 2 , -1 , 5 });
50+ Tensor x2 = tf.make (
51+ {1 , 2 , 5 , 3 }, {0 , 1 , 2 , 3 , 5 , -3 , 7 , 1 , 6 , 2 , -1 , 5 , 1 , 1 , -2 ,
52+ 4 , 3 , 2 , -1 , 5 , 1 , 1 , -2 , 1 , 5 , 4 , 3 , 2 , -1 , 5 });
53+ optional<int64_t > compute_mode = optional<int64_t >();
4754
48- Tensor x1 =
49- tfFloat.make ({2 , 1 , 4 , 3 }, {0 , 1 , 2 , 3 , 5 , 4 , 3 , -3 , 7 , 1 , 6 , 2 ,
50- -1 , 5 , 1 , 1 , -2 , 1 , 5 , 4 , 3 , 2 , -1 , 5 });
51- Tensor x2 = tfFloat.make (
52- {1 , 2 , 5 , 3 }, {0 , 1 , 2 , 3 , 5 , -3 , 7 , 1 , 6 , 2 , -1 , 5 , 1 , 1 , -2 ,
53- 4 , 3 , 2 , -1 , 5 , 1 , 1 , -2 , 1 , 5 , 4 , 3 , 2 , -1 , 5 });
54- optional<int64_t > compute_mode = optional<int64_t >();
55+ Tensor out = tf.zeros ({2 , 2 , 4 , 5 });
5556
56- Tensor out = tfFloat.zeros ({2 , 2 , 4 , 5 });
57+ Tensor l0 = tf.make (
58+ {2 , 2 , 4 , 5 },
59+ {0 ., 3 ., 2 ., 3 ., 2 ., 3 ., 1 ., 3 ., 3 ., 3 ., 3 ., 2 ., 3 ., 3 ., 3 ., 2 .,
60+ 3 ., 3 ., 3 ., 2 ., 2 ., 3 ., 3 ., 3 ., 3 ., 3 ., 2 ., 3 ., 3 ., 3 ., 3 ., 3 .,
61+ 3 ., 3 ., 3 ., 2 ., 3 ., 2 ., 3 ., 3 ., 3 ., 2 ., 3 ., 3 ., 3 ., 3 ., 3 ., 3 .,
62+ 3 ., 2 ., 3 ., 3 ., 3 ., 3 ., 3 ., 3 ., 3 ., 3 ., 0 ., 3 ., 3 ., 0 ., 2 ., 3 .,
63+ 3 ., 3 ., 2 ., 0 ., 3 ., 3 ., 3 ., 3 ., 3 ., 0 ., 3 ., 3 ., 3 ., 3 ., 3 ., 0 .});
64+ op_cdist_forward_out (x1, x2, 0.0 , compute_mode, out);
65+ EXPECT_TENSOR_CLOSE (out, l0);
5766
58- Tensor l0 = tfFloat.make (
59- {2 , 2 , 4 , 5 },
60- {0 ., 3 ., 2 ., 3 ., 2 ., 3 ., 1 ., 3 ., 3 ., 3 ., 3 ., 2 ., 3 ., 3 ., 3 ., 2 .,
61- 3 ., 3 ., 3 ., 2 ., 2 ., 3 ., 3 ., 3 ., 3 ., 3 ., 2 ., 3 ., 3 ., 3 ., 3 ., 3 .,
62- 3 ., 3 ., 3 ., 2 ., 3 ., 2 ., 3 ., 3 ., 3 ., 2 ., 3 ., 3 ., 3 ., 3 ., 3 ., 3 .,
63- 3 ., 2 ., 3 ., 3 ., 3 ., 3 ., 3 ., 3 ., 3 ., 3 ., 0 ., 3 ., 3 ., 0 ., 2 ., 3 .,
64- 3 ., 3 ., 2 ., 0 ., 3 ., 3 ., 3 ., 3 ., 3 ., 0 ., 3 ., 3 ., 3 ., 3 ., 3 ., 0 .});
65- op_cdist_forward_out (x1, x2, 0.0 , compute_mode, out);
66- EXPECT_TENSOR_CLOSE (out, l0);
67+ Tensor l1 = tf.make (
68+ {2 , 2 , 4 , 5 },
69+ {0 ., 12 ., 11 ., 7 ., 5 ., 9 ., 7 ., 10 ., 8 ., 12 ., 12 ., 18 ., 9 ., 5 .,
70+ 15 ., 6 ., 8 ., 15 ., 11 ., 9 ., 6 ., 6 ., 5 ., 9 ., 7 ., 5 ., 7 ., 12 .,
71+ 4 ., 8 ., 12 ., 18 ., 9 ., 13 ., 5 ., 6 ., 4 ., 9 ., 7 ., 11 ., 6 ., 8 .,
72+ 17 ., 13 ., 9 ., 5 ., 13 ., 14 ., 6 ., 6 ., 9 ., 9 ., 8 ., 10 ., 12 ., 7 .,
73+ 15 ., 8 ., 0 ., 10 ., 8 ., 0 ., 9 ., 9 ., 13 ., 9 ., 9 ., 0 ., 12 ., 6 .,
74+ 3 ., 9 ., 12 ., 0 ., 10 ., 9 ., 13 ., 6 ., 10 ., 0 .});
75+ op_cdist_forward_out (x1, x2, 1.0 , compute_mode, out);
76+ EXPECT_TENSOR_CLOSE (out, l1);
6777
68- Tensor l1 = tfFloat.make (
69- {2 , 2 , 4 , 5 },
70- {0 ., 12 ., 11 ., 7 ., 5 ., 9 ., 7 ., 10 ., 8 ., 12 ., 12 ., 18 ., 9 ., 5 .,
71- 15 ., 6 ., 8 ., 15 ., 11 ., 9 ., 6 ., 6 ., 5 ., 9 ., 7 ., 5 ., 7 ., 12 .,
72- 4 ., 8 ., 12 ., 18 ., 9 ., 13 ., 5 ., 6 ., 4 ., 9 ., 7 ., 11 ., 6 ., 8 .,
73- 17 ., 13 ., 9 ., 5 ., 13 ., 14 ., 6 ., 6 ., 9 ., 9 ., 8 ., 10 ., 12 ., 7 .,
74- 15 ., 8 ., 0 ., 10 ., 8 ., 0 ., 9 ., 9 ., 13 ., 9 ., 9 ., 0 ., 12 ., 6 .,
75- 3 ., 9 ., 12 ., 0 ., 10 ., 9 ., 13 ., 6 ., 10 ., 0 .});
76- op_cdist_forward_out (x1, x2, 1.0 , compute_mode, out);
77- EXPECT_TENSOR_CLOSE (out, l1);
78+ Tensor l2 = tf.make (
79+ {2 , 2 , 4 , 5 },
80+ {0.00000000 , 7.07106781 , 8.06225777 , 4.12310553 , 4.12310553 ,
81+ 5.38516474 , 7.00000000 , 6.00000000 , 6.16441393 , 7.48331499 ,
82+ 7.07106781 , 12.80624866 , 5.74456263 , 3.00000000 , 10.04987526 ,
83+ 5.09901953 , 5.47722578 , 8.77496433 , 7.68114567 , 6.40312433 ,
84+ 4.47213602 , 4.24264050 , 3.31662488 , 5.91608000 , 4.12310553 ,
85+ 3.00000000 , 5.00000000 , 7.87400770 , 2.44948983 , 6.16441393 ,
86+ 7.87400770 , 10.77032948 , 6.40312433 , 8.30662346 , 3.00000000 ,
87+ 4.24264050 , 2.44948983 , 8.06225777 , 4.58257580 , 7.68114567 ,
88+ 4.24264050 , 5.65685415 , 10.24695110 , 7.81024981 , 5.38516474 ,
89+ 3.31662488 , 8.30662346 , 8.36660004 , 4.24264050 , 4.24264050 ,
90+ 5.91608000 , 6.40312433 , 4.69041586 , 6.16441393 , 7.07106781 ,
91+ 4.12310553 , 10.04987526 , 5.47722578 , 0.00000000 , 7.34846926 ,
92+ 5.47722578 , 0.00000000 , 7.28010988 , 6.40312433 , 7.81024981 ,
93+ 5.91608000 , 7.28010988 , 0.00000000 , 7.48331499 , 4.24264050 ,
94+ 1.73205078 , 6.40312433 , 7.48331499 , 0.00000000 , 6.16441393 ,
95+ 5.38516474 , 7.81024981 , 4.24264050 , 6.16441393 , 0.00000000 });
96+ op_cdist_forward_out (x1, x2, 2.0 , compute_mode, out);
97+ EXPECT_TENSOR_CLOSE (out, l2);
7898
79- Tensor l2 = tfFloat.make (
80- {2 , 2 , 4 , 5 },
81- {0.00000000 , 7.07106781 , 8.06225777 , 4.12310553 , 4.12310553 ,
82- 5.38516474 , 7.00000000 , 6.00000000 , 6.16441393 , 7.48331499 ,
83- 7.07106781 , 12.80624866 , 5.74456263 , 3.00000000 , 10.04987526 ,
84- 5.09901953 , 5.47722578 , 8.77496433 , 7.68114567 , 6.40312433 ,
85- 4.47213602 , 4.24264050 , 3.31662488 , 5.91608000 , 4.12310553 ,
86- 3.00000000 , 5.00000000 , 7.87400770 , 2.44948983 , 6.16441393 ,
87- 7.87400770 , 10.77032948 , 6.40312433 , 8.30662346 , 3.00000000 ,
88- 4.24264050 , 2.44948983 , 8.06225777 , 4.58257580 , 7.68114567 ,
89- 4.24264050 , 5.65685415 , 10.24695110 , 7.81024981 , 5.38516474 ,
90- 3.31662488 , 8.30662346 , 8.36660004 , 4.24264050 , 4.24264050 ,
91- 5.91608000 , 6.40312433 , 4.69041586 , 6.16441393 , 7.07106781 ,
92- 4.12310553 , 10.04987526 , 5.47722578 , 0.00000000 , 7.34846926 ,
93- 5.47722578 , 0.00000000 , 7.28010988 , 6.40312433 , 7.81024981 ,
94- 5.91608000 , 7.28010988 , 0.00000000 , 7.48331499 , 4.24264050 ,
95- 1.73205078 , 6.40312433 , 7.48331499 , 0.00000000 , 6.16441393 ,
96- 5.38516474 , 7.81024981 , 4.24264050 , 6.16441393 , 0.00000000 });
97- op_cdist_forward_out (x1, x2, 2.0 , compute_mode, out);
98- EXPECT_TENSOR_CLOSE (out, l2);
99+ Tensor l3 = tf.make (
100+ {2 , 2 , 4 , 5 },
101+ {0.00000000 , 6.00000000 , 7.41079521 , 3.50339794 , 4.02072573 ,
102+ 4.62606478 , 7.00000000 , 5.14256334 , 6.01846170 , 6.60385466 ,
103+ 6.00000000 , 11.47758675 , 5.05277443 , 2.57128167 , 9.28704357 ,
104+ 5.01329803 , 5.11722994 , 7.39863634 , 7.18551636 , 5.73879337 ,
105+ 4.16016769 , 4.04124022 , 3.07231688 , 5.34848118 , 3.50339794 ,
106+ 2.57128167 , 4.49794149 , 7.23042679 , 2.15443468 , 6.01846170 ,
107+ 6.99319077 , 9.25212955 , 6.08220196 , 7.45903587 , 2.57128167 ,
108+ 3.77976322 , 2.15443468 , 8.00520515 , 4.17933941 , 7.18551636 ,
109+ 4.04124022 , 5.03968430 , 8.88326645 , 6.74599648 , 4.62606478 ,
110+ 3.07231688 , 7.45903587 , 7.16609573 , 4.04124022 , 3.77976322 ,
111+ 5.34848118 , 6.08220196 , 3.95789170 , 5.42883539 , 6.00000000 ,
112+ 3.50339794 , 9.00000000 , 5.11722994 , 0.00000000 , 7.06069660 ,
113+ 5.11722994 , 0.00000000 , 7.05400419 , 6.08220196 , 6.74599648 ,
114+ 5.34848118 , 7.05400419 , 0.00000000 , 6.60385466 , 4.04124022 ,
115+ 1.44224954 , 6.08220196 , 6.60385466 , 0.00000000 , 5.42883539 ,
116+ 4.62606478 , 6.74599648 , 4.04124022 , 5.42883539 , 0.00000000 });
117+ op_cdist_forward_out (x1, x2, 3.0 , compute_mode, out);
118+ if (DTYPE == ScalarType::BFloat16) {
119+ EXPECT_TENSOR_CLOSE_WITH_TOL (
120+ out,
121+ l3,
122+ 1e-2 ,
123+ executorch::runtime::testing::internal::kDefaultBFloat16Atol );
124+ } else {
125+ EXPECT_TENSOR_CLOSE (out, l3);
126+ }
99127
100- Tensor l3 = tfFloat.make (
101- {2 , 2 , 4 , 5 },
102- {0.00000000 , 6.00000000 , 7.41079521 , 3.50339794 , 4.02072573 , 4.62606478 ,
103- 7.00000000 , 5.14256334 , 6.01846170 , 6.60385466 , 6.00000000 , 11.47758675 ,
104- 5.05277443 , 2.57128167 , 9.28704357 , 5.01329803 , 5.11722994 , 7.39863634 ,
105- 7.18551636 , 5.73879337 , 4.16016769 , 4.04124022 , 3.07231688 , 5.34848118 ,
106- 3.50339794 , 2.57128167 , 4.49794149 , 7.23042679 , 2.15443468 , 6.01846170 ,
107- 6.99319077 , 9.25212955 , 6.08220196 , 7.45903587 , 2.57128167 , 3.77976322 ,
108- 2.15443468 , 8.00520515 , 4.17933941 , 7.18551636 , 4.04124022 , 5.03968430 ,
109- 8.88326645 , 6.74599648 , 4.62606478 , 3.07231688 , 7.45903587 , 7.16609573 ,
110- 4.04124022 , 3.77976322 , 5.34848118 , 6.08220196 , 3.95789170 , 5.42883539 ,
111- 6.00000000 , 3.50339794 , 9.00000000 , 5.11722994 , 0.00000000 , 7.06069660 ,
112- 5.11722994 , 0.00000000 , 7.05400419 , 6.08220196 , 6.74599648 , 5.34848118 ,
113- 7.05400419 , 0.00000000 , 6.60385466 , 4.04124022 , 1.44224954 , 6.08220196 ,
114- 6.60385466 , 0.00000000 , 5.42883539 , 4.62606478 , 6.74599648 , 4.04124022 ,
115- 5.42883539 , 0.00000000 });
116- op_cdist_forward_out (x1, x2, 3.0 , compute_mode, out);
117- EXPECT_TENSOR_CLOSE (out, l3);
128+ Tensor linf = tf.make (
129+ {2 , 2 , 4 , 5 },
130+ {0 ., 5 ., 7 ., 3 ., 4 ., 4 ., 7 ., 4 ., 6 ., 6 ., 5 ., 10 ., 4 ., 2 ., 9 ., 5 .,
131+ 5 ., 6 ., 7 ., 5 ., 4 ., 4 ., 3 ., 5 ., 3 ., 2 ., 4 ., 7 ., 2 ., 6 ., 6 ., 8 .,
132+ 6 ., 7 ., 2 ., 3 ., 2 ., 8 ., 4 ., 7 ., 4 ., 4 ., 8 ., 6 ., 4 ., 3 ., 7 ., 6 .,
133+ 4 ., 3 ., 5 ., 6 ., 3 ., 5 ., 5 ., 3 ., 8 ., 5 ., 0 ., 7 ., 5 ., 0 ., 7 ., 6 .,
134+ 6 ., 5 ., 7 ., 0 ., 6 ., 4 ., 1 ., 6 ., 6 ., 0 ., 5 ., 4 ., 6 ., 4 ., 5 ., 0 .});
135+ op_cdist_forward_out (x1, x2, INFINITY, compute_mode, out);
136+ EXPECT_TENSOR_CLOSE (out, linf);
137+ }
138+ };
118139
119- Tensor linf = tfFloat.make (
120- {2 , 2 , 4 , 5 },
121- {0 ., 5 ., 7 ., 3 ., 4 ., 4 ., 7 ., 4 ., 6 ., 6 ., 5 ., 10 ., 4 ., 2 ., 9 ., 5 .,
122- 5 ., 6 ., 7 ., 5 ., 4 ., 4 ., 3 ., 5 ., 3 ., 2 ., 4 ., 7 ., 2 ., 6 ., 6 ., 8 .,
123- 6 ., 7 ., 2 ., 3 ., 2 ., 8 ., 4 ., 7 ., 4 ., 4 ., 8 ., 6 ., 4 ., 3 ., 7 ., 6 .,
124- 4 ., 3 ., 5 ., 6 ., 3 ., 5 ., 5 ., 3 ., 8 ., 5 ., 0 ., 7 ., 5 ., 0 ., 7 ., 6 .,
125- 6 ., 5 ., 7 ., 0 ., 6 ., 4 ., 1 ., 6 ., 6 ., 0 ., 5 ., 4 ., 6 ., 4 ., 5 ., 0 .});
126- op_cdist_forward_out (x1, x2, INFINITY, compute_mode, out);
127- EXPECT_TENSOR_CLOSE (out, linf);
140+ TEST_F (OpCdistForwardOutTest, SmokeTest) {
141+ #define TEST_ENTRY (ctype, dtype ) test_dtype<ScalarType::dtype>();
142+ ET_FORALL_FLOATHBF16_TYPES (TEST_ENTRY);
143+ #undef TEST_ENTRY
128144}
0 commit comments