@@ -72,7 +72,7 @@ class OpMulOutTest : public OperatorTest {
7272#define ENUMERATE_TEST_ENTRY (ctype, dtype ) \
7373 test_mul_enumerate_out_types<DTYPE_A, ScalarType::dtype>();
7474
75- ET_FORALL_REAL_TYPES_AND (Half, ENUMERATE_TEST_ENTRY)
75+ ET_FORALL_REALHBF16_TYPES ( ENUMERATE_TEST_ENTRY)
7676
7777#undef ENUMERATE_TEST_ENTRY
7878 }
@@ -89,29 +89,99 @@ class OpMulOutTest : public OperatorTest {
8989
9090 // Multiply two tensors
9191 op_mul_out (
92- tf.make (sizes, /* data=*/ {1.1 , 2.2 , 4.4 , 8.8 }), tf.ones (sizes), out);
93- EXPECT_TENSOR_CLOSE (out, tf.make (sizes, /* data=*/ {1.1 , 2.2 , 4.4 , 8.8 }));
92+ tf.make (sizes, /* data=*/ {1.25 , 2.5 , 4.75 , 8.875 }), tf.ones (sizes), out);
93+ EXPECT_TENSOR_CLOSE (out, tf.make (sizes, /* data=*/ {1.25 , 2.5 , 4.75 , 8.875 }));
9494
9595 op_mul_out (
9696 tf.make (sizes, /* data=*/ {1.1 , 2.2 , 4.4 , 8.8 }), tf.zeros (sizes), out);
9797 EXPECT_TENSOR_CLOSE (out, tf.make (sizes, /* data=*/ {0.0 , 0.0 , 0.0 , 0.0 }));
9898
9999 op_mul_out (
100- tf.make (sizes, /* data=*/ {1.1 , 2.2 , 4.4 , 8.8 }),
101- tf.make (sizes, /* data=*/ {1.1 , 2.2 , 4.4 , 8.8 }),
100+ tf.make (sizes, /* data=*/ {1.25 , 2.5 , 4.75 , 8.875 }),
101+ tf.make (sizes, /* data=*/ {1.25 , 2.5 , 4.75 , 8.875 }),
102102 out);
103103 EXPECT_TENSOR_CLOSE (
104- out, tf.make (sizes, /* data=*/ {1.21 , 4.84 , 19.36 , 77.44 }));
104+ out, tf.make (sizes, /* data=*/ {1.5625 , 6.25 , 22.5625 , 78.765625 }));
105105 }
106106
107107 void test_mul_enumerate_a_types () {
108108#define ENUMERATE_TEST_ENTRY (ctype, dtype ) \
109109 test_mul_enumerate_b_types<ScalarType::dtype>();
110110
111- ET_FORALL_REAL_TYPES_AND (Half, ENUMERATE_TEST_ENTRY)
111+ ET_FORALL_REALHBF16_TYPES ( ENUMERATE_TEST_ENTRY)
112112
113113#undef ENUMERATE_TEST_ENTRY
114114 }
115+
116+ template <ScalarType DTYPE>
117+ void test_optimized_path_ignores_leading_1_dimensions () {
118+ TensorFactory<DTYPE> tf;
119+
120+ const std::vector<int32_t > sizes1 = {1 , 1 , 2 , 2 };
121+ const std::vector<int32_t > sizes2 = {1 , 2 , 2 };
122+
123+ // Destination for the mul.
124+ Tensor out = tf.zeros (sizes1);
125+
126+ // Multiply two tensors
127+ op_mul_out (
128+ tf.make (sizes1, /* data=*/ {1.1 , 2.2 , 4.4 , 8.8 }), tf.ones (sizes2), out);
129+ EXPECT_TENSOR_CLOSE (out, tf.make (sizes1, /* data=*/ {1.1 , 2.2 , 4.4 , 8.8 }));
130+ }
131+
132+ template <ScalarType DTYPE>
133+ void test_broadcast_a2b () {
134+ TensorFactory<DTYPE> tf_a;
135+
136+ std::vector<std::vector<int32_t >> b_sizeses = {
137+ {2 },
138+ {1 , 2 },
139+ };
140+ for (const auto & b_sizes : b_sizeses) {
141+ // a and b of different shapes
142+ Tensor a = tf_a.make ({2 , 2 }, /* data=*/ {1 , 2 , 3 , 4 });
143+ Tensor b = tf_a.make (b_sizes, /* data=*/ {2 , 2 });
144+
145+ // Destination for output of mul.
146+ Tensor out = tf_a.zeros ({2 , 2 });
147+
148+ // Check that it matches the expected output.
149+ EXPECT_TENSOR_CLOSE (
150+ op_mul_out (a, b, out), tf_a.make ({2 , 2 }, /* data=*/ {2 , 4 , 6 , 8 }));
151+ }
152+ }
153+
154+ template <ScalarType DTYPE>
155+ void test_broadcast_b2a () {
156+ TensorFactory<DTYPE> tf_a;
157+ // a and b of different shapes
158+ Tensor a = tf_a.make ({2 }, /* data=*/ {2 , 2 });
159+ Tensor b = tf_a.make ({2 , 2 }, /* data=*/ {1 , 2 , 3 , 4 });
160+
161+ // Destination for output of mul.
162+ Tensor out = tf_a.zeros ({2 , 2 });
163+
164+ // Check that it matches the expected output.
165+ EXPECT_TENSOR_CLOSE (
166+ op_mul_out (a, b, out), tf_a.make ({2 , 2 }, /* data=*/ {2 , 4 , 6 , 8 }));
167+ }
168+
169+ template <ScalarType DTYPE>
170+ void test_scalar_input_broadcast () {
171+ TensorFactory<DTYPE> tf_a;
172+
173+ // a is a 1d tensor and b is a scalar
174+ Tensor a = tf_a.make ({2 }, /* data=*/ {2 , 2 });
175+ Tensor b = tf_a.make ({}, /* data=*/ {2 });
176+
177+ // Destination for output of mul.
178+ Tensor out = tf_a.make ({2 }, /* data=*/ {2 , 2 });
179+ Tensor expected = tf_a.make ({2 }, /* data=*/ {4 , 4 });
180+
181+ // Check that it matches the expected output.
182+ EXPECT_TENSOR_CLOSE (op_mul_out (a, b, out), expected);
183+ EXPECT_TENSOR_CLOSE (op_mul_out (b, a, out), expected);
184+ }
115185};
116186
117187class OpMulScalarOutTest : public OperatorTest {
@@ -141,6 +211,14 @@ TEST_F(OpMulOutTest, DoubleTensors) {
141211 test_floating_point_mul_out<ScalarType::Double>();
142212}
143213
214+ TEST_F (OpMulOutTest, HalfTensors) {
215+ test_floating_point_mul_out<ScalarType::Half>();
216+ }
217+
218+ TEST_F (OpMulOutTest, BFloat16Tensors) {
219+ test_floating_point_mul_out<ScalarType::BFloat16>();
220+ }
221+
144222TEST_F (OpMulOutTest, BoolTensors) {
145223 TensorFactory<ScalarType::Bool> tf;
146224
@@ -166,18 +244,12 @@ TEST_F(OpMulOutTest, BoolTensors) {
166244}
167245
168246TEST_F (OpMulOutTest, OptimizedPathIgnoresLeading1Dimensions) {
169- TensorFactory<ScalarType::Float> tf;
247+ #define ENUMERATE_TEST_ENTRY (ctype, dtype ) \
248+ test_optimized_path_ignores_leading_1_dimensions<ScalarType::dtype>();
170249
171- const std::vector<int32_t > sizes1 = {1 , 1 , 2 , 2 };
172- const std::vector<int32_t > sizes2 = {1 , 2 , 2 };
250+ ET_FORALL_FLOATHBF16_TYPES (ENUMERATE_TEST_ENTRY);
173251
174- // Destination for the mul.
175- Tensor out = tf.zeros (sizes1);
176-
177- // Multiply two tensors
178- op_mul_out (
179- tf.make (sizes1, /* data=*/ {1.1 , 2.2 , 4.4 , 8.8 }), tf.ones (sizes2), out);
180- EXPECT_TENSOR_CLOSE (out, tf.make (sizes1, /* data=*/ {1.1 , 2.2 , 4.4 , 8.8 }));
252+ #undef ENUMERATE_TEST_ENTRY
181253}
182254
183255// Mismatched shape tests.
@@ -202,40 +274,16 @@ TEST_F(OpMulOutTest, MismatchedNonBroadcastableInputShapesDies) {
202274
203275// Broadcast tensor b's size to tensor a's size
204276TEST_F (OpMulOutTest, BroadcastA2BTest) {
205- TensorFactory<ScalarType::Int> tf_a;
206-
207- std::vector<std::vector<int32_t >> b_sizeses = {
208- {2 },
209- {1 , 2 },
210- };
211- for (const auto & b_sizes : b_sizeses) {
212- // a and b of different shapes
213- Tensor a = tf_a.make ({2 , 2 }, /* data=*/ {1 , 2 , 3 , 4 });
214- Tensor b = tf_a.make (b_sizes, /* data=*/ {2 , 2 });
215-
216- // Destination for output of mul.
217- Tensor out = tf_a.zeros ({2 , 2 });
218-
219- // Check that it matches the expected output.
220- EXPECT_TENSOR_CLOSE (
221- op_mul_out (a, b, out), tf_a.make ({2 , 2 }, /* data=*/ {2 , 4 , 6 , 8 }));
222- }
277+ test_broadcast_a2b<ScalarType::Int>();
278+ test_broadcast_a2b<ScalarType::Half>();
279+ test_broadcast_a2b<ScalarType::BFloat16>();
223280}
224281
225282// Broadcast tensor a's size to tensor b's size
226283TEST_F (OpMulOutTest, BroadcastB2ATest) {
227- TensorFactory<ScalarType::Int> tf_a;
228-
229- // a and b of different shapes
230- Tensor a = tf_a.make ({2 }, /* data=*/ {2 , 2 });
231- Tensor b = tf_a.make ({2 , 2 }, /* data=*/ {1 , 2 , 3 , 4 });
232-
233- // Destination for output of mul.
234- Tensor out = tf_a.zeros ({2 , 2 });
235-
236- // Check that it matches the expected output.
237- EXPECT_TENSOR_CLOSE (
238- op_mul_out (a, b, out), tf_a.make ({2 , 2 }, /* data=*/ {2 , 4 , 6 , 8 }));
284+ test_broadcast_b2a<ScalarType::Int>();
285+ test_broadcast_b2a<ScalarType::Half>();
286+ test_broadcast_b2a<ScalarType::BFloat16>();
239287}
240288
241289// Broadcast tensor a and b's size to a new size c.
@@ -256,19 +304,9 @@ TEST_F(OpMulOutTest, BroadcastAB2CTest) {
256304}
257305
258306TEST_F (OpMulOutTest, ScalarInputBroadcastTest) {
259- TensorFactory<ScalarType::Int> tf_a;
260-
261- // a is a 1d tensor and b is a scalar
262- Tensor a = tf_a.make ({2 }, /* data=*/ {2 , 2 });
263- Tensor b = tf_a.make ({}, /* data=*/ {2 });
264-
265- // Destination for output of mul.
266- Tensor out = tf_a.make ({2 }, /* data=*/ {2 , 2 });
267- Tensor expected = tf_a.make ({2 }, /* data=*/ {4 , 4 });
268-
269- // Check that it matches the expected output.
270- EXPECT_TENSOR_CLOSE (op_mul_out (a, b, out), expected);
271- EXPECT_TENSOR_CLOSE (op_mul_out (b, a, out), expected);
307+ test_scalar_input_broadcast<ScalarType::Int>();
308+ test_scalar_input_broadcast<ScalarType::Half>();
309+ test_scalar_input_broadcast<ScalarType::BFloat16>();
272310}
273311
274312TEST_F (OpMulOutTest, MismatchedOutputShapesDies) {
0 commit comments