@@ -99,6 +99,109 @@ class OpSubOutTest : public OperatorTest {
9999 EXPECT_TENSOR_CLOSE (out, tf.make (sizes, /* data=*/ {0.1 , 1.2 , 3.4 , 7.8 }));
100100 }
101101
102+ template <ScalarType DTYPE>
103+ void test_broadcast_3D () {
104+ TensorFactory<DTYPE> tf_a;
105+
106+ Tensor a =
107+ tf_a.make ({2 , 2 , 3 }, /* data=*/ {1 , 2 , 3 , 4 , 5 , 6 , 7 , 8 , 9 , 10 , 11 , 12 });
108+ Tensor b = tf_a.make ({2 , 1 , 3 }, /* data=*/ {2 , 3 , 4 , 5 , 6 , 7 });
109+
110+ // Destination for output of mul.
111+ Tensor out =
112+ tf_a.make ({2 , 2 , 3 }, /* data=*/ {1 , 2 , 3 , 4 , 5 , 6 , 7 , 8 , 9 , 10 , 11 , 12 });
113+ Tensor expected =
114+ tf_a.make ({2 , 2 , 3 }, /* data=*/ {-1 , -1 , -1 , 2 , 2 , 2 , 2 , 2 , 2 , 5 , 5 , 5 });
115+
116+ // Check that it matches the expected output.
117+ EXPECT_TENSOR_CLOSE (op_sub_out (a, b, 1.0 , out), expected);
118+ // b - a * 1.5 output should be
119+ expected = tf_a.make (
120+ {2 , 2 , 3 },
121+ /* data=*/
122+ {0.5 ,
123+ 0.0 ,
124+ -0.5 ,
125+ -4.0 ,
126+ -4.5 ,
127+ -5.0 ,
128+ -5.5 ,
129+ -6.0 ,
130+ -6.5 ,
131+ -10.0 ,
132+ -10.5 ,
133+ -11.0 });
134+ EXPECT_TENSOR_CLOSE (op_sub_out (b, a, 1.5 , out), expected);
135+ }
136+
137+ template <ScalarType DTYPE>
138+ void test_broadcast_4D () {
139+ TensorFactory<DTYPE> tf_a;
140+
141+ Tensor a = tf_a.make (
142+ {2 , 2 , 3 , 5 },
143+ /* data=*/ {1 , 2 , 3 , 4 , 5 , 6 , 7 , 8 , 9 , 10 , 11 , 12 , 13 , 14 , 15 ,
144+ 16 , 17 , 18 , 19 , 20 , 21 , 22 , 23 , 24 , 25 , 26 , 27 , 28 , 29 , 30 ,
145+ 31 , 32 , 33 , 34 , 35 , 36 , 37 , 38 , 39 , 40 , 41 , 42 , 43 , 44 , 45 ,
146+ 46 , 47 , 48 , 49 , 50 , 51 , 52 , 53 , 54 , 55 , 56 , 57 , 58 , 59 , 60 });
147+ Tensor b = tf_a.make (
148+ {2 , 1 , 3 , 5 },
149+ /* data=*/ {1 , 2 , 3 , 4 , 5 , 6 , 7 , 8 , 9 , 10 , 11 , 12 , 13 , 14 , 15 ,
150+ 16 , 17 , 18 , 19 , 20 , 21 , 22 , 23 , 24 , 25 , 26 , 27 , 28 , 29 , 30 });
151+
152+ // Destination for output of mul.
153+ Tensor out = tf_a.zeros ({2 , 2 , 3 , 5 });
154+ Tensor expected = tf_a.make (
155+ {2 , 2 , 3 , 5 },
156+ /* data=*/ {0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 ,
157+ 15 , 15 , 15 , 15 , 15 , 15 , 15 , 15 , 15 , 15 , 15 , 15 , 15 , 15 , 15 ,
158+ 15 , 15 , 15 , 15 , 15 , 15 , 15 , 15 , 15 , 15 , 15 , 15 , 15 , 15 , 15 ,
159+ 30 , 30 , 30 , 30 , 30 , 30 , 30 , 30 , 30 , 30 , 30 , 30 , 30 , 30 , 30 });
160+
161+ // Check that it matches the expected output.
162+ EXPECT_TENSOR_CLOSE (op_sub_out (a, b, 1.0 , out), expected);
163+ expected = tf_a.make (
164+ {2 , 2 , 3 , 5 },
165+ /* data=*/ {0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 ,
166+ 0 , 0 , 0 , -15 , -15 , -15 , -15 , -15 , -15 , -15 , -15 , -15 ,
167+ -15 , -15 , -15 , -15 , -15 , -15 , -15 , -15 , -15 , -15 , -15 , -15 ,
168+ -15 , -15 , -15 , -15 , -15 , -15 , -15 , -15 , -15 , -30 , -30 , -30 ,
169+ -30 , -30 , -30 , -30 , -30 , -30 , -30 , -30 , -30 , -30 , -30 , -30 });
170+ EXPECT_TENSOR_CLOSE (op_sub_out (b, a, 1.0 , out), expected);
171+
172+ b = tf_a.make (
173+ {2 , 2 , 1 , 5 }, /* data=*/ {1 , 2 , 3 , 4 , 5 , 6 , 7 , 8 , 9 , 10 ,
174+ 11 , 12 , 13 , 14 , 15 , 16 , 17 , 18 , 19 , 20 });
175+ out = tf_a.zeros ({2 , 2 , 3 , 5 });
176+ expected = tf_a.make (
177+ {2 , 2 , 3 , 5 },
178+ /* data=*/ {0 , 0 , 0 , 0 , 0 , 5 , 5 , 5 , 5 , 5 , 10 , 10 , 10 , 10 , 10 ,
179+ 10 , 10 , 10 , 10 , 10 , 15 , 15 , 15 , 15 , 15 , 20 , 20 , 20 , 20 , 20 ,
180+ 20 , 20 , 20 , 20 , 20 , 25 , 25 , 25 , 25 , 25 , 30 , 30 , 30 , 30 , 30 ,
181+ 30 , 30 , 30 , 30 , 30 , 35 , 35 , 35 , 35 , 35 , 40 , 40 , 40 , 40 , 40 });
182+
183+ // Check that it matches the expected output.
184+ EXPECT_TENSOR_CLOSE (op_sub_out (a, b, 1.0 , out), expected);
185+ expected = tf_a.make (
186+ {2 , 2 , 3 , 5 },
187+ /* data=*/ {-0.5000 , -1.0000 , -1.5000 , -2.0000 , -2.5000 ,
188+ -8.0000 , -8.5000 , -9.0000 , -9.5000 , -10.0000 ,
189+ -15.5000 , -16.0000 , -16.5000 , -17.0000 , -17.5000 ,
190+
191+ -18.0000 , -18.5000 , -19.0000 , -19.5000 , -20.0000 ,
192+ -25.5000 , -26.0000 , -26.5000 , -27.0000 , -27.5000 ,
193+ -33.0000 , -33.5000 , -34.0000 , -34.5000 , -35.0000 ,
194+
195+ -35.5000 , -36.0000 , -36.5000 , -37.0000 , -37.5000 ,
196+ -43.0000 , -43.5000 , -44.0000 , -44.5000 , -45.0000 ,
197+ -50.5000 , -51.0000 , -51.5000 , -52.0000 , -52.5000 ,
198+
199+ -53.0000 , -53.5000 , -54.0000 , -54.5000 , -55.0000 ,
200+ -60.5000 , -61.0000 , -61.5000 , -62.0000 , -62.5000 ,
201+ -68.0000 , -68.5000 , -69.0000 , -69.5000 , -70.0000 });
202+ EXPECT_TENSOR_CLOSE (op_sub_out (b, a, 1.5 , out), expected);
203+ }
204+
102205 void test_sub_enumerate_a_types () {
103206#define ENUMERATE_TEST_ENTRY (ctype, dtype ) \
104207 test_sub_enumerate_b_types<ScalarType::dtype>();
@@ -237,6 +340,19 @@ TEST_F(OpSubOutTest, BroadcastScalarRank0Supported) {
237340 EXPECT_TENSOR_EQ (out, ret);
238341}
239342
343+ TEST_F (OpSubOutTest, BroadcastNDTest) {
344+ // Test 3D tensors
345+ test_broadcast_3D<ScalarType::Float>();
346+ test_broadcast_3D<ScalarType::Half>();
347+ // Sub doesnt yet support BFloat16
348+ // test_broadcast_3D<ScalarType::BFloat16>();
349+
350+ // Test 4D tensors
351+ test_broadcast_4D<ScalarType::Float>();
352+ test_broadcast_4D<ScalarType::Half>();
353+ // test_broadcast_4D<ScalarType::BFloat16>();
354+ }
355+
240356//
241357// Death Tests
242358//
0 commit comments