@@ -220,6 +220,61 @@ class OpMulOutTest : public OperatorTest {
220220 EXPECT_TENSOR_CLOSE (op_mul_out (b, a, out), expected);
221221 }
222222
223+ template <ScalarType DTYPE>
224+ void test_broadcast_last_dim () {
225+ TensorFactory<DTYPE> tf_a;
226+
227+ Tensor a =
228+ tf_a.make ({4 , 3 }, /* data=*/ {1 , 2 , 3 , 4 , 5 , 6 , 7 , 8 , 9 , 10 , 11 , 12 });
229+ Tensor b = tf_a.make ({4 , 1 }, /* data=*/ {2 , 3 , 4 , 5 });
230+
231+ // Destination for output of mul.
232+ Tensor out = tf_a.zeros ({4 , 3 });
233+ Tensor expected = tf_a.make (
234+ {4 , 3 }, /* data=*/ {2 , 4 , 6 , 12 , 15 , 18 , 28 , 32 , 36 , 50 , 55 , 60 });
235+
236+ // Check that it matches the expected output.
237+ EXPECT_TENSOR_CLOSE (op_mul_out (a, b, out), expected);
238+ EXPECT_TENSOR_CLOSE (op_mul_out (b, a, out), expected);
239+
240+ a =
241+ tf_a.make ({2 , 2 , 3 }, /* data=*/ {1 , 2 , 3 , 4 , 5 , 6 , 7 , 8 , 9 , 10 , 11 , 12 });
242+ b = tf_a.make ({2 , 2 , 1 }, /* data=*/ {2 , 3 , 4 , 5 });
243+
244+ // Destination for output of mul.
245+ out = tf_a.zeros ({2 , 2 , 3 });
246+ expected = tf_a.make (
247+ {2 , 2 , 3 }, /* data=*/ {2 , 4 , 6 , 12 , 15 , 18 , 28 , 32 , 36 , 50 , 55 , 60 });
248+
249+ // Check that it matches the expected output.
250+ EXPECT_TENSOR_CLOSE (op_mul_out (a, b, out), expected);
251+ EXPECT_TENSOR_CLOSE (op_mul_out (b, a, out), expected);
252+
253+ a = tf_a.make (
254+ {2 , 2 , 3 , 5 },
255+ /* data=*/ {1 , 2 , 3 , 4 , 5 , 6 , 7 , 8 , 9 , 10 , 11 , 12 , 13 , 14 , 15 ,
256+ 16 , 17 , 18 , 19 , 20 , 21 , 22 , 23 , 24 , 25 , 26 , 27 , 28 , 29 , 30 ,
257+ 31 , 32 , 33 , 34 , 35 , 36 , 37 , 38 , 39 , 40 , 41 , 42 , 43 , 44 , 45 ,
258+ 46 , 47 , 48 , 49 , 50 , 51 , 52 , 53 , 54 , 55 , 56 , 57 , 58 , 59 , 60 });
259+ b = tf_a.make (
260+ {2 , 2 , 3 , 1 },
261+ /* data=*/ {1 , 2 , 3 , 4 , 5 , 6 , 7 , 8 , 9 , 10 , 11 , 12 });
262+
263+ // Destination for output of mul.
264+ out = tf_a.zeros ({2 , 2 , 3 , 5 });
265+ expected = tf_a.make (
266+ {2 , 2 , 3 , 5 },
267+ /* data=*/ {1 , 2 , 3 , 4 , 5 , 12 , 14 , 16 , 18 , 20 , 33 , 36 ,
268+ 39 , 42 , 45 , 64 , 68 , 72 , 76 , 80 , 105 , 110 , 115 , 120 ,
269+ 125 , 156 , 162 , 168 , 174 , 180 , 217 , 224 , 231 , 238 , 245 , 288 ,
270+ 296 , 304 , 312 , 320 , 369 , 378 , 387 , 396 , 405 , 460 , 470 , 480 ,
271+ 490 , 500 , 561 , 572 , 583 , 594 , 605 , 672 , 684 , 696 , 708 , 720 });
272+
273+ // Check that it matches the expected output.
274+ EXPECT_TENSOR_CLOSE (op_mul_out (a, b, out), expected);
275+ EXPECT_TENSOR_CLOSE (op_mul_out (b, a, out), expected);
276+ }
277+
223278 template <ScalarType DTYPE>
224279 void test_broadcast_b2a () {
225280 TensorFactory<DTYPE> tf_a;
@@ -392,6 +447,18 @@ TEST_F(OpMulOutTest, BroadcastNDTest) {
392447 test_broadcast_4D<ScalarType::Float>();
393448 test_broadcast_4D<ScalarType::Half>();
394449 test_broadcast_4D<ScalarType::BFloat16>();
450+
451+ // Test broadcasting on the last dimension
452+ test_broadcast_last_dim<ScalarType::Float>();
453+ test_broadcast_last_dim<ScalarType::Half>();
454+ test_broadcast_last_dim<ScalarType::BFloat16>();
455+ }
456+
457+ TEST_F (OpMulOutTest, BroadcastLastDimTest) {
458+ // Test broadcasting on the last dimension
459+ test_broadcast_last_dim<ScalarType::Float>();
460+ test_broadcast_last_dim<ScalarType::Half>();
461+ test_broadcast_last_dim<ScalarType::BFloat16>();
395462}
396463
397464// Broadcast tensor a and b's size to a new size c.
0 commit comments