@@ -220,6 +220,60 @@ class OpMulOutTest : public OperatorTest {
220220 EXPECT_TENSOR_CLOSE (op_mul_out (b, a, out), expected);
221221 }
222222
223+ template <ScalarType DTYPE>
224+ void test_broadcast_last_dim () {
225+ TensorFactory<DTYPE> tf_a;
226+
227+ Tensor a =
228+ tf_a.make ({4 , 3 }, /* data=*/ {1 , 2 , 3 , 4 , 5 , 6 , 7 , 8 , 9 , 10 , 11 , 12 });
229+ Tensor b = tf_a.make ({4 , 1 }, /* data=*/ {2 , 3 , 4 , 5 });
230+
231+ // Destination for output of mul.
232+ Tensor out = tf_a.zeros ({4 , 3 });
233+ Tensor expected = tf_a.make (
234+ {4 , 3 }, /* data=*/ {2 , 4 , 6 , 12 , 15 , 18 , 28 , 32 , 36 , 50 , 55 , 60 });
235+
236+ // Check that it matches the expected output.
237+ EXPECT_TENSOR_CLOSE (op_mul_out (a, b, out), expected);
238+ EXPECT_TENSOR_CLOSE (op_mul_out (b, a, out), expected);
239+
240+ a = tf_a.make ({2 , 2 , 3 }, /* data=*/ {1 , 2 , 3 , 4 , 5 , 6 , 7 , 8 , 9 , 10 , 11 , 12 });
241+ b = tf_a.make ({2 , 2 , 1 }, /* data=*/ {2 , 3 , 4 , 5 });
242+
243+ // Destination for output of mul.
244+ out = tf_a.zeros ({2 , 2 , 3 });
245+ expected = tf_a.make (
246+ {2 , 2 , 3 }, /* data=*/ {2 , 4 , 6 , 12 , 15 , 18 , 28 , 32 , 36 , 50 , 55 , 60 });
247+
248+ // Check that it matches the expected output.
249+ EXPECT_TENSOR_CLOSE (op_mul_out (a, b, out), expected);
250+ EXPECT_TENSOR_CLOSE (op_mul_out (b, a, out), expected);
251+
252+ a = tf_a.make (
253+ {2 , 2 , 3 , 5 },
254+ /* data=*/ {1 , 2 , 3 , 4 , 5 , 6 , 7 , 8 , 9 , 10 , 11 , 12 , 13 , 14 , 15 ,
255+ 16 , 17 , 18 , 19 , 20 , 21 , 22 , 23 , 24 , 25 , 26 , 27 , 28 , 29 , 30 ,
256+ 31 , 32 , 33 , 34 , 35 , 36 , 37 , 38 , 39 , 40 , 41 , 42 , 43 , 44 , 45 ,
257+ 46 , 47 , 48 , 49 , 50 , 51 , 52 , 53 , 54 , 55 , 56 , 57 , 58 , 59 , 60 });
258+ b = tf_a.make (
259+ {2 , 2 , 3 , 1 },
260+ /* data=*/ {1 , 2 , 3 , 4 , 5 , 6 , 7 , 8 , 9 , 10 , 11 , 12 });
261+
262+ // Destination for output of mul.
263+ out = tf_a.zeros ({2 , 2 , 3 , 5 });
264+ expected = tf_a.make (
265+ {2 , 2 , 3 , 5 },
266+ /* data=*/ {1 , 2 , 3 , 4 , 5 , 12 , 14 , 16 , 18 , 20 , 33 , 36 ,
267+ 39 , 42 , 45 , 64 , 68 , 72 , 76 , 80 , 105 , 110 , 115 , 120 ,
268+ 125 , 156 , 162 , 168 , 174 , 180 , 217 , 224 , 231 , 238 , 245 , 288 ,
269+ 296 , 304 , 312 , 320 , 369 , 378 , 387 , 396 , 405 , 460 , 470 , 480 ,
270+ 490 , 500 , 561 , 572 , 583 , 594 , 605 , 672 , 684 , 696 , 708 , 720 });
271+
272+ // Check that it matches the expected output.
273+ EXPECT_TENSOR_CLOSE (op_mul_out (a, b, out), expected);
274+ EXPECT_TENSOR_CLOSE (op_mul_out (b, a, out), expected);
275+ }
276+
223277 template <ScalarType DTYPE>
224278 void test_broadcast_b2a () {
225279 TensorFactory<DTYPE> tf_a;
@@ -392,6 +446,18 @@ TEST_F(OpMulOutTest, BroadcastNDTest) {
392446 test_broadcast_4D<ScalarType::Float>();
393447 test_broadcast_4D<ScalarType::Half>();
394448 test_broadcast_4D<ScalarType::BFloat16>();
449+
450+ // Test broadcasting on the last dimension
451+ test_broadcast_last_dim<ScalarType::Float>();
452+ test_broadcast_last_dim<ScalarType::Half>();
453+ test_broadcast_last_dim<ScalarType::BFloat16>();
454+ }
455+
456+ TEST_F (OpMulOutTest, BroadcastLastDimTest) {
457+ // Test broadcasting on the last dimension
458+ test_broadcast_last_dim<ScalarType::Float>();
459+ test_broadcast_last_dim<ScalarType::Half>();
460+ test_broadcast_last_dim<ScalarType::BFloat16>();
395461}
396462
397463// Broadcast tensor a and b's size to a new size c.
0 commit comments