@@ -112,6 +112,122 @@ class OpAddOutKernelTest : public OperatorTest {
112112 // tests.
113113 EXPECT_TENSOR_CLOSE (out, tf.make (sizes, /* data=*/ {2.5 , 3.5 , 5.75 , 10.125 }));
114114 }
115+
116+ template <ScalarType DTYPE>
117+ void test_broadcast_3D () {
118+ TensorFactory<DTYPE> tf_a;
119+
120+ Tensor a =
121+ tf_a.make ({2 , 2 , 3 }, /* data=*/ {1 , 2 , 3 , 4 , 5 , 6 , 7 , 8 , 9 , 10 , 11 , 12 });
122+ Tensor b = tf_a.make ({2 , 1 , 3 }, /* data=*/ {2 , 3 , 4 , 5 , 6 , 7 });
123+
124+ // Destination for output of mul.
125+ Tensor out =
126+ tf_a.make ({2 , 2 , 3 }, /* data=*/ {1 , 2 , 3 , 4 , 5 , 6 , 7 , 8 , 9 , 10 , 11 , 12 });
127+ Tensor expected = tf_a.make (
128+ {2 , 2 , 3 }, /* data=*/ {3 , 5 , 7 , 6 , 8 , 10 , 12 , 14 , 16 , 15 , 17 , 19 });
129+
130+ // Check that it matches the expected output.
131+ EXPECT_TENSOR_CLOSE (op_add_out (a, b, 1.0 , out), expected);
132+ EXPECT_TENSOR_CLOSE (op_add_out (b, a, 1.0 , out), expected);
133+ }
134+
135+ template <ScalarType DTYPE>
136+ void test_broadcast_4D () {
137+ TensorFactory<DTYPE> tf_a;
138+
139+ Tensor a = tf_a.make (
140+ {2 , 2 , 3 , 5 },
141+ /* data=*/ {1 , 2 , 3 , 4 , 5 , 6 , 7 , 8 , 9 , 10 , 11 , 12 , 13 , 14 , 15 ,
142+ 16 , 17 , 18 , 19 , 20 , 21 , 22 , 23 , 24 , 25 , 26 , 27 , 28 , 29 , 30 ,
143+ 31 , 32 , 33 , 34 , 35 , 36 , 37 , 38 , 39 , 40 , 41 , 42 , 43 , 44 , 45 ,
144+ 46 , 47 , 48 , 49 , 50 , 51 , 52 , 53 , 54 , 55 , 56 , 57 , 58 , 59 , 60 });
145+ Tensor b = tf_a.make (
146+ {2 , 1 , 3 , 5 },
147+ /* data=*/ {1 , 2 , 3 , 4 , 5 , 6 , 7 , 8 , 9 , 10 , 11 , 12 , 13 , 14 , 15 ,
148+ 16 , 17 , 18 , 19 , 20 , 21 , 22 , 23 , 24 , 25 , 26 , 27 , 28 , 29 , 30 });
149+
150+ // Destination for output of mul.
151+ Tensor out = tf_a.zeros ({2 , 2 , 3 , 5 });
152+ Tensor expected = tf_a.make (
153+ {2 , 2 , 3 , 5 },
154+ /* data=*/ {2 , 4 , 6 , 8 , 10 , 12 , 14 , 16 , 18 , 20 , 22 , 24 , 26 , 28 , 30 ,
155+ 17 , 19 , 21 , 23 , 25 , 27 , 29 , 31 , 33 , 35 , 37 , 39 , 41 , 43 , 45 ,
156+ 47 , 49 , 51 , 53 , 55 , 57 , 59 , 61 , 63 , 65 , 67 , 69 , 71 , 73 , 75 ,
157+ 62 , 64 , 66 , 68 , 70 , 72 , 74 , 76 , 78 , 80 , 82 , 84 , 86 , 88 , 90 });
158+
159+ // Check that it matches the expected output.
160+ EXPECT_TENSOR_CLOSE (op_add_out (a, b, 1.0 , out), expected);
161+ EXPECT_TENSOR_CLOSE (op_add_out (b, a, 1.0 , out), expected);
162+
163+ b = tf_a.make (
164+ {2 , 2 , 1 , 5 }, /* data=*/ {1 , 2 , 3 , 4 , 5 , 6 , 7 , 8 , 9 , 10 ,
165+ 11 , 12 , 13 , 14 , 15 , 16 , 17 , 18 , 19 , 20 });
166+ out = tf_a.zeros ({2 , 2 , 3 , 5 });
167+ expected = tf_a.make (
168+ {2 , 2 , 3 , 5 },
169+ /* data=*/ {2 , 4 , 6 , 8 , 10 , 7 , 9 , 11 , 13 , 15 , 12 , 14 , 16 , 18 , 20 ,
170+ 22 , 24 , 26 , 28 , 30 , 27 , 29 , 31 , 33 , 35 , 32 , 34 , 36 , 38 , 40 ,
171+ 42 , 44 , 46 , 48 , 50 , 47 , 49 , 51 , 53 , 55 , 52 , 54 , 56 , 58 , 60 ,
172+ 62 , 64 , 66 , 68 , 70 , 67 , 69 , 71 , 73 , 75 , 72 , 74 , 76 , 78 , 80 });
173+
174+ // Check that it matches the expected output.
175+ EXPECT_TENSOR_CLOSE (op_add_out (a, b, 1.0 , out), expected);
176+ EXPECT_TENSOR_CLOSE (op_add_out (b, a, 1.0 , out), expected);
177+ }
178+
179+ template <ScalarType DTYPE>
180+ void test_broadcast_last_dim () {
181+ TensorFactory<DTYPE> tf_a;
182+
183+ Tensor a =
184+ tf_a.make ({4 , 3 }, /* data=*/ {1 , 2 , 3 , 4 , 5 , 6 , 7 , 8 , 9 , 10 , 11 , 12 });
185+ Tensor b = tf_a.make ({4 , 1 }, /* data=*/ {2 , 3 , 4 , 5 });
186+
187+ // Destination for output of mul.
188+ Tensor out = tf_a.zeros ({4 , 3 });
189+ Tensor expected =
190+ tf_a.make ({4 , 3 }, /* data=*/ {3 , 4 , 5 , 7 , 8 , 9 , 11 , 12 , 13 , 15 , 16 , 17 });
191+
192+ // Check that it matches the expected output.
193+ EXPECT_TENSOR_CLOSE (op_add_out (a, b, 1.0 , out), expected);
194+ EXPECT_TENSOR_CLOSE (op_add_out (b, a, 1.0 , out), expected);
195+
196+ a = tf_a.make ({2 , 2 , 3 }, /* data=*/ {1 , 2 , 3 , 4 , 5 , 6 , 7 , 8 , 9 , 10 , 11 , 12 });
197+ b = tf_a.make ({2 , 2 , 1 }, /* data=*/ {2 , 3 , 4 , 5 });
198+
199+ // Destination for output of mul.
200+ out = tf_a.zeros ({2 , 2 , 3 });
201+ expected = tf_a.make (
202+ {2 , 2 , 3 }, /* data=*/ {3 , 4 , 5 , 7 , 8 , 9 , 11 , 12 , 13 , 15 , 16 , 17 });
203+
204+ // Check that it matches the expected output.
205+ EXPECT_TENSOR_CLOSE (op_add_out (a, b, 1.0 , out), expected);
206+ EXPECT_TENSOR_CLOSE (op_add_out (b, a, 1.0 , out), expected);
207+
208+ a = tf_a.make (
209+ {2 , 2 , 3 , 5 },
210+ /* data=*/ {1 , 2 , 3 , 4 , 5 , 6 , 7 , 8 , 9 , 10 , 11 , 12 , 13 , 14 , 15 ,
211+ 16 , 17 , 18 , 19 , 20 , 21 , 22 , 23 , 24 , 25 , 26 , 27 , 28 , 29 , 30 ,
212+ 31 , 32 , 33 , 34 , 35 , 36 , 37 , 38 , 39 , 40 , 41 , 42 , 43 , 44 , 45 ,
213+ 46 , 47 , 48 , 49 , 50 , 51 , 52 , 53 , 54 , 55 , 56 , 57 , 58 , 59 , 60 });
214+ b = tf_a.make (
215+ {2 , 2 , 3 , 1 },
216+ /* data=*/ {1 , 2 , 3 , 4 , 5 , 6 , 7 , 8 , 9 , 10 , 11 , 12 });
217+
218+ // Destination for output of mul.
219+ out = tf_a.zeros ({2 , 2 , 3 , 5 });
220+ expected = tf_a.make (
221+ {2 , 2 , 3 , 5 },
222+ /* data=*/ {2 , 3 , 4 , 5 , 6 , 8 , 9 , 10 , 11 , 12 , 14 , 15 , 16 , 17 , 18 ,
223+ 20 , 21 , 22 , 23 , 24 , 26 , 27 , 28 , 29 , 30 , 32 , 33 , 34 , 35 , 36 ,
224+ 38 , 39 , 40 , 41 , 42 , 44 , 45 , 46 , 47 , 48 , 50 , 51 , 52 , 53 , 54 ,
225+ 56 , 57 , 58 , 59 , 60 , 62 , 63 , 64 , 65 , 66 , 68 , 69 , 70 , 71 , 72 });
226+
227+ // Check that it matches the expected output.
228+ EXPECT_TENSOR_CLOSE (op_add_out (a, b, 1.0 , out), expected);
229+ EXPECT_TENSOR_CLOSE (op_add_out (b, a, 1.0 , out), expected);
230+ }
115231};
116232
117233class OpAddScalarOutKernelTest : public OperatorTest {
@@ -371,6 +487,23 @@ TEST_F(OpAddOutKernelTest, BroadcastOneElementRank0Tensor) {
371487 EXPECT_TENSOR_EQ (out, ret);
372488}
373489
490+ TEST_F (OpAddOutKernelTest, BroadcastNDTest) {
491+ // Test 3D tensors
492+ test_broadcast_3D<ScalarType::Float>();
493+ test_broadcast_3D<ScalarType::Half>();
494+ test_broadcast_3D<ScalarType::BFloat16>();
495+
496+ // Test 4D tensors
497+ test_broadcast_4D<ScalarType::Float>();
498+ test_broadcast_4D<ScalarType::Half>();
499+ test_broadcast_4D<ScalarType::BFloat16>();
500+
501+ // Test broadcasting on the last dimension
502+ test_broadcast_last_dim<ScalarType::Float>();
503+ test_broadcast_last_dim<ScalarType::Half>();
504+ test_broadcast_last_dim<ScalarType::BFloat16>();
505+ }
506+
374507//
375508// Death Tests
376509//
0 commit comments