@@ -99,6 +99,109 @@ class OpSubOutTest : public OperatorTest {
9999    EXPECT_TENSOR_CLOSE (out, tf.make (sizes, /* data=*/  {0.1 , 1.2 , 3.4 , 7.8 }));
100100  }
101101
102+   template  <ScalarType DTYPE>
103+   void  test_broadcast_3D () {
104+     TensorFactory<DTYPE> tf_a;
105+ 
106+     Tensor a =
107+         tf_a.make ({2 , 2 , 3 }, /* data=*/  {1 , 2 , 3 , 4 , 5 , 6 , 7 , 8 , 9 , 10 , 11 , 12 });
108+     Tensor b = tf_a.make ({2 , 1 , 3 }, /* data=*/  {2 , 3 , 4 , 5 , 6 , 7 });
109+ 
110+     //  Destination for output of mul.
111+     Tensor out =
112+         tf_a.make ({2 , 2 , 3 }, /* data=*/  {1 , 2 , 3 , 4 , 5 , 6 , 7 , 8 , 9 , 10 , 11 , 12 });
113+     Tensor expected =
114+         tf_a.make ({2 , 2 , 3 }, /* data=*/  {-1 , -1 , -1 , 2 , 2 , 2 , 2 , 2 , 2 , 5 , 5 , 5 });
115+ 
116+     //  Check that it matches the expected output.
117+     EXPECT_TENSOR_CLOSE (op_sub_out (a, b, 1.0 , out), expected);
118+     //  b - a * 1.5 output should be
119+     expected = tf_a.make (
120+         {2 , 2 , 3 },
121+         /* data=*/ 
122+         {0.5 ,
123+          0.0 ,
124+          -0.5 ,
125+          -4.0 ,
126+          -4.5 ,
127+          -5.0 ,
128+          -5.5 ,
129+          -6.0 ,
130+          -6.5 ,
131+          -10.0 ,
132+          -10.5 ,
133+          -11.0 });
134+     EXPECT_TENSOR_CLOSE (op_sub_out (b, a, 1.5 , out), expected);
135+   }
136+ 
137+   template  <ScalarType DTYPE>
138+   void  test_broadcast_4D () {
139+     TensorFactory<DTYPE> tf_a;
140+ 
141+     Tensor a = tf_a.make (
142+         {2 , 2 , 3 , 5 },
143+         /* data=*/  {1 ,  2 ,  3 ,  4 ,  5 ,  6 ,  7 ,  8 ,  9 ,  10 , 11 , 12 , 13 , 14 , 15 ,
144+                   16 , 17 , 18 , 19 , 20 , 21 , 22 , 23 , 24 , 25 , 26 , 27 , 28 , 29 , 30 ,
145+                   31 , 32 , 33 , 34 , 35 , 36 , 37 , 38 , 39 , 40 , 41 , 42 , 43 , 44 , 45 ,
146+                   46 , 47 , 48 , 49 , 50 , 51 , 52 , 53 , 54 , 55 , 56 , 57 , 58 , 59 , 60 });
147+     Tensor b = tf_a.make (
148+         {2 , 1 , 3 , 5 },
149+         /* data=*/  {1 ,  2 ,  3 ,  4 ,  5 ,  6 ,  7 ,  8 ,  9 ,  10 , 11 , 12 , 13 , 14 , 15 ,
150+                   16 , 17 , 18 , 19 , 20 , 21 , 22 , 23 , 24 , 25 , 26 , 27 , 28 , 29 , 30 });
151+ 
152+     //  Destination for output of mul.
153+     Tensor out = tf_a.zeros ({2 , 2 , 3 , 5 });
154+     Tensor expected = tf_a.make (
155+         {2 , 2 , 3 , 5 },
156+         /* data=*/  {0 ,  0 ,  0 ,  0 ,  0 ,  0 ,  0 ,  0 ,  0 ,  0 ,  0 ,  0 ,  0 ,  0 ,  0 ,
157+                   15 , 15 , 15 , 15 , 15 , 15 , 15 , 15 , 15 , 15 , 15 , 15 , 15 , 15 , 15 ,
158+                   15 , 15 , 15 , 15 , 15 , 15 , 15 , 15 , 15 , 15 , 15 , 15 , 15 , 15 , 15 ,
159+                   30 , 30 , 30 , 30 , 30 , 30 , 30 , 30 , 30 , 30 , 30 , 30 , 30 , 30 , 30 });
160+ 
161+     //  Check that it matches the expected output.
162+     EXPECT_TENSOR_CLOSE (op_sub_out (a, b, 1.0 , out), expected);
163+     expected = tf_a.make (
164+         {2 , 2 , 3 , 5 },
165+         /* data=*/  {0 ,   0 ,   0 ,   0 ,   0 ,   0 ,   0 ,   0 ,   0 ,   0 ,   0 ,   0 ,
166+                   0 ,   0 ,   0 ,   -15 , -15 , -15 , -15 , -15 , -15 , -15 , -15 , -15 ,
167+                   -15 , -15 , -15 , -15 , -15 , -15 , -15 , -15 , -15 , -15 , -15 , -15 ,
168+                   -15 , -15 , -15 , -15 , -15 , -15 , -15 , -15 , -15 , -30 , -30 , -30 ,
169+                   -30 , -30 , -30 , -30 , -30 , -30 , -30 , -30 , -30 , -30 , -30 , -30 });
170+     EXPECT_TENSOR_CLOSE (op_sub_out (b, a, 1.0 , out), expected);
171+ 
172+     b = tf_a.make (
173+         {2 , 2 , 1 , 5 }, /* data=*/  {1 ,  2 ,  3 ,  4 ,  5 ,  6 ,  7 ,  8 ,  9 ,  10 ,
174+                                 11 , 12 , 13 , 14 , 15 , 16 , 17 , 18 , 19 , 20 });
175+     out = tf_a.zeros ({2 , 2 , 3 , 5 });
176+     expected = tf_a.make (
177+         {2 , 2 , 3 , 5 },
178+         /* data=*/  {0 ,  0 ,  0 ,  0 ,  0 ,  5 ,  5 ,  5 ,  5 ,  5 ,  10 , 10 , 10 , 10 , 10 ,
179+                   10 , 10 , 10 , 10 , 10 , 15 , 15 , 15 , 15 , 15 , 20 , 20 , 20 , 20 , 20 ,
180+                   20 , 20 , 20 , 20 , 20 , 25 , 25 , 25 , 25 , 25 , 30 , 30 , 30 , 30 , 30 ,
181+                   30 , 30 , 30 , 30 , 30 , 35 , 35 , 35 , 35 , 35 , 40 , 40 , 40 , 40 , 40 });
182+ 
183+     //  Check that it matches the expected output.
184+     EXPECT_TENSOR_CLOSE (op_sub_out (a, b, 1.0 , out), expected);
185+     expected = tf_a.make (
186+         {2 , 2 , 3 , 5 },
187+         /* data=*/  {-0.5000 ,  -1.0000 ,  -1.5000 ,  -2.0000 ,  -2.5000 ,
188+                   -8.0000 ,  -8.5000 ,  -9.0000 ,  -9.5000 ,  -10.0000 ,
189+                   -15.5000 , -16.0000 , -16.5000 , -17.0000 , -17.5000 ,
190+ 
191+                   -18.0000 , -18.5000 , -19.0000 , -19.5000 , -20.0000 ,
192+                   -25.5000 , -26.0000 , -26.5000 , -27.0000 , -27.5000 ,
193+                   -33.0000 , -33.5000 , -34.0000 , -34.5000 , -35.0000 ,
194+ 
195+                   -35.5000 , -36.0000 , -36.5000 , -37.0000 , -37.5000 ,
196+                   -43.0000 , -43.5000 , -44.0000 , -44.5000 , -45.0000 ,
197+                   -50.5000 , -51.0000 , -51.5000 , -52.0000 , -52.5000 ,
198+ 
199+                   -53.0000 , -53.5000 , -54.0000 , -54.5000 , -55.0000 ,
200+                   -60.5000 , -61.0000 , -61.5000 , -62.0000 , -62.5000 ,
201+                   -68.0000 , -68.5000 , -69.0000 , -69.5000 , -70.0000 });
202+     EXPECT_TENSOR_CLOSE (op_sub_out (b, a, 1.5 , out), expected);
203+   }
204+ 
102205  void  test_sub_enumerate_a_types () {
103206#define  ENUMERATE_TEST_ENTRY (ctype, dtype ) \
104207  test_sub_enumerate_b_types<ScalarType::dtype>();
@@ -237,6 +340,19 @@ TEST_F(OpSubOutTest, BroadcastScalarRank0Supported) {
237340  EXPECT_TENSOR_EQ (out, ret);
238341}
239342
343+ TEST_F (OpSubOutTest, BroadcastNDTest) {
344+   //  Test 3D tensors
345+   test_broadcast_3D<ScalarType::Float>();
346+   test_broadcast_3D<ScalarType::Half>();
347+   //  Sub doesnt yet support BFloat16
348+   //  test_broadcast_3D<ScalarType::BFloat16>();
349+ 
350+   //  Test 4D tensors
351+   test_broadcast_4D<ScalarType::Float>();
352+   test_broadcast_4D<ScalarType::Half>();
353+   //  test_broadcast_4D<ScalarType::BFloat16>();
354+ }
355+ 
240356// 
241357//  Death Tests
242358// 
0 commit comments