@@ -112,6 +112,125 @@ class OpAddOutKernelTest : public OperatorTest {
112112 // tests.
113113 EXPECT_TENSOR_CLOSE (out, tf.make (sizes, /* data=*/ {2.5 , 3.5 , 5.75 , 10.125 }));
114114 }
115+
116+ template <ScalarType DTYPE>
117+ void test_broadcast_3D () {
118+ TensorFactory<DTYPE> tf_a;
119+
120+ Tensor a =
121+ tf_a.make ({2 , 2 , 3 }, /* data=*/ {1 , 2 , 3 , 4 , 5 , 6 , 7 , 8 , 9 , 10 , 11 , 12 });
122+ Tensor b = tf_a.make ({2 , 1 , 3 }, /* data=*/ {2 , 3 , 4 , 5 , 6 , 7 });
123+
124+ // Destination for output of mul.
125+ Tensor out =
126+ tf_a.make ({2 , 2 , 3 }, /* data=*/ {1 , 2 , 3 , 4 , 5 , 6 , 7 , 8 , 9 , 10 , 11 , 12 });
127+ Tensor expected = tf_a.make (
128+ {2 , 2 , 3 }, /* data=*/ {3 , 5 , 7 , 6 , 8 , 10 , 12 , 14 , 16 , 15 , 17 , 19 });
129+
130+ // Check that it matches the expected output.
131+ EXPECT_TENSOR_CLOSE (op_add_out (a, b, 1.0 , out), expected);
132+ expected = tf_a.make (
133+ {2 , 2 , 3 },
134+ /* data=*/ {3.5 , 6 , 8.5 , 8 , 10.5 , 13 , 15.5 , 18 , 20.5 , 20 , 22.5 , 25 });
135+ EXPECT_TENSOR_CLOSE (op_add_out (b, a, 1.5 , out), expected);
136+ }
137+
138+ template <ScalarType DTYPE>
139+ void test_broadcast_4D () {
140+ TensorFactory<DTYPE> tf_a;
141+
142+ Tensor a = tf_a.make (
143+ {2 , 2 , 3 , 5 },
144+ /* data=*/ {1 , 2 , 3 , 4 , 5 , 6 , 7 , 8 , 9 , 10 , 11 , 12 , 13 , 14 , 15 ,
145+ 16 , 17 , 18 , 19 , 20 , 21 , 22 , 23 , 24 , 25 , 26 , 27 , 28 , 29 , 30 ,
146+ 31 , 32 , 33 , 34 , 35 , 36 , 37 , 38 , 39 , 40 , 41 , 42 , 43 , 44 , 45 ,
147+ 46 , 47 , 48 , 49 , 50 , 51 , 52 , 53 , 54 , 55 , 56 , 57 , 58 , 59 , 60 });
148+ Tensor b = tf_a.make (
149+ {2 , 1 , 3 , 5 },
150+ /* data=*/ {1 , 2 , 3 , 4 , 5 , 6 , 7 , 8 , 9 , 10 , 11 , 12 , 13 , 14 , 15 ,
151+ 16 , 17 , 18 , 19 , 20 , 21 , 22 , 23 , 24 , 25 , 26 , 27 , 28 , 29 , 30 });
152+
153+ // Destination for output of mul.
154+ Tensor out = tf_a.zeros ({2 , 2 , 3 , 5 });
155+ Tensor expected = tf_a.make (
156+ {2 , 2 , 3 , 5 },
157+ /* data=*/ {2 , 4 , 6 , 8 , 10 , 12 , 14 , 16 , 18 , 20 , 22 , 24 , 26 , 28 , 30 ,
158+ 17 , 19 , 21 , 23 , 25 , 27 , 29 , 31 , 33 , 35 , 37 , 39 , 41 , 43 , 45 ,
159+ 47 , 49 , 51 , 53 , 55 , 57 , 59 , 61 , 63 , 65 , 67 , 69 , 71 , 73 , 75 ,
160+ 62 , 64 , 66 , 68 , 70 , 72 , 74 , 76 , 78 , 80 , 82 , 84 , 86 , 88 , 90 });
161+
162+ // Check that it matches the expected output.
163+ EXPECT_TENSOR_CLOSE (op_add_out (a, b, 1.0 , out), expected);
164+ EXPECT_TENSOR_CLOSE (op_add_out (b, a, 1.0 , out), expected);
165+
166+ b = tf_a.make (
167+ {2 , 2 , 1 , 5 }, /* data=*/ {1 , 2 , 3 , 4 , 5 , 6 , 7 , 8 , 9 , 10 ,
168+ 11 , 12 , 13 , 14 , 15 , 16 , 17 , 18 , 19 , 20 });
169+ out = tf_a.zeros ({2 , 2 , 3 , 5 });
170+ expected = tf_a.make (
171+ {2 , 2 , 3 , 5 },
172+ /* data=*/ {2 , 4 , 6 , 8 , 10 , 7 , 9 , 11 , 13 , 15 , 12 , 14 , 16 , 18 , 20 ,
173+ 22 , 24 , 26 , 28 , 30 , 27 , 29 , 31 , 33 , 35 , 32 , 34 , 36 , 38 , 40 ,
174+ 42 , 44 , 46 , 48 , 50 , 47 , 49 , 51 , 53 , 55 , 52 , 54 , 56 , 58 , 60 ,
175+ 62 , 64 , 66 , 68 , 70 , 67 , 69 , 71 , 73 , 75 , 72 , 74 , 76 , 78 , 80 });
176+
177+ // Check that it matches the expected output.
178+ EXPECT_TENSOR_CLOSE (op_add_out (a, b, 1.0 , out), expected);
179+ EXPECT_TENSOR_CLOSE (op_add_out (b, a, 1.0 , out), expected);
180+ }
181+
182+ template <ScalarType DTYPE>
183+ void test_broadcast_last_dim () {
184+ TensorFactory<DTYPE> tf_a;
185+
186+ Tensor a =
187+ tf_a.make ({4 , 3 }, /* data=*/ {1 , 2 , 3 , 4 , 5 , 6 , 7 , 8 , 9 , 10 , 11 , 12 });
188+ Tensor b = tf_a.make ({4 , 1 }, /* data=*/ {2 , 3 , 4 , 5 });
189+
190+ // Destination for output of mul.
191+ Tensor out = tf_a.zeros ({4 , 3 });
192+ Tensor expected =
193+ tf_a.make ({4 , 3 }, /* data=*/ {3 , 4 , 5 , 7 , 8 , 9 , 11 , 12 , 13 , 15 , 16 , 17 });
194+
195+ // Check that it matches the expected output.
196+ EXPECT_TENSOR_CLOSE (op_add_out (a, b, 1.0 , out), expected);
197+ EXPECT_TENSOR_CLOSE (op_add_out (b, a, 1.0 , out), expected);
198+
199+ a = tf_a.make ({2 , 2 , 3 }, /* data=*/ {1 , 2 , 3 , 4 , 5 , 6 , 7 , 8 , 9 , 10 , 11 , 12 });
200+ b = tf_a.make ({2 , 2 , 1 }, /* data=*/ {2 , 3 , 4 , 5 });
201+
202+ // Destination for output of mul.
203+ out = tf_a.zeros ({2 , 2 , 3 });
204+ expected = tf_a.make (
205+ {2 , 2 , 3 }, /* data=*/ {3 , 4 , 5 , 7 , 8 , 9 , 11 , 12 , 13 , 15 , 16 , 17 });
206+
207+ // Check that it matches the expected output.
208+ EXPECT_TENSOR_CLOSE (op_add_out (a, b, 1.0 , out), expected);
209+ EXPECT_TENSOR_CLOSE (op_add_out (b, a, 1.0 , out), expected);
210+
211+ a = tf_a.make (
212+ {2 , 2 , 3 , 5 },
213+ /* data=*/ {1 , 2 , 3 , 4 , 5 , 6 , 7 , 8 , 9 , 10 , 11 , 12 , 13 , 14 , 15 ,
214+ 16 , 17 , 18 , 19 , 20 , 21 , 22 , 23 , 24 , 25 , 26 , 27 , 28 , 29 , 30 ,
215+ 31 , 32 , 33 , 34 , 35 , 36 , 37 , 38 , 39 , 40 , 41 , 42 , 43 , 44 , 45 ,
216+ 46 , 47 , 48 , 49 , 50 , 51 , 52 , 53 , 54 , 55 , 56 , 57 , 58 , 59 , 60 });
217+ b = tf_a.make (
218+ {2 , 2 , 3 , 1 },
219+ /* data=*/ {1 , 2 , 3 , 4 , 5 , 6 , 7 , 8 , 9 , 10 , 11 , 12 });
220+
221+ // Destination for output of mul.
222+ out = tf_a.zeros ({2 , 2 , 3 , 5 });
223+ expected = tf_a.make (
224+ {2 , 2 , 3 , 5 },
225+ /* data=*/ {2 , 3 , 4 , 5 , 6 , 8 , 9 , 10 , 11 , 12 , 14 , 15 , 16 , 17 , 18 ,
226+ 20 , 21 , 22 , 23 , 24 , 26 , 27 , 28 , 29 , 30 , 32 , 33 , 34 , 35 , 36 ,
227+ 38 , 39 , 40 , 41 , 42 , 44 , 45 , 46 , 47 , 48 , 50 , 51 , 52 , 53 , 54 ,
228+ 56 , 57 , 58 , 59 , 60 , 62 , 63 , 64 , 65 , 66 , 68 , 69 , 70 , 71 , 72 });
229+
230+ // Check that it matches the expected output.
231+ EXPECT_TENSOR_CLOSE (op_add_out (a, b, 1.0 , out), expected);
232+ EXPECT_TENSOR_CLOSE (op_add_out (b, a, 1.0 , out), expected);
233+ }
115234};
116235
117236class OpAddScalarOutKernelTest : public OperatorTest {
@@ -371,6 +490,23 @@ TEST_F(OpAddOutKernelTest, BroadcastOneElementRank0Tensor) {
371490 EXPECT_TENSOR_EQ (out, ret);
372491}
373492
493+ TEST_F (OpAddOutKernelTest, BroadcastNDTest) {
494+ // Test 3D tensors
495+ test_broadcast_3D<ScalarType::Float>();
496+ test_broadcast_3D<ScalarType::Half>();
497+ test_broadcast_3D<ScalarType::BFloat16>();
498+
499+ // Test 4D tensors
500+ test_broadcast_4D<ScalarType::Float>();
501+ test_broadcast_4D<ScalarType::Half>();
502+ test_broadcast_4D<ScalarType::BFloat16>();
503+
504+ // Test broadcasting on the last dimension
505+ test_broadcast_last_dim<ScalarType::Float>();
506+ test_broadcast_last_dim<ScalarType::Half>();
507+ test_broadcast_last_dim<ScalarType::BFloat16>();
508+ }
509+
374510//
375511// Death Tests
376512//
0 commit comments