@@ -112,8 +112,66 @@ TEST(BroadcastIndexesRangeTest, OneAndTwoDExhaustive) {
   EXPECT_EQ(expected, actual);
 }
 
-// Here we assume that the previous tests established that padding
-// with leading 1s is working, and test:
+// Make sure nothing is thrown off by a size-1 dim in the output:
+// [] -> [1, W]
+// [] -> [H, 1]
+// [1] -> [1, W]
+// [1] -> [H, 1]
+// [W] -> [1, W]
+// [1, 1] -> [1, W]
+// [1, 1] -> [H, 1]
+// [1, W] -> [1, W]
+// [H, 1] -> [H, 1]
+TEST(BroadcastIndexesRangeTest, OneAndTwoDWith1InOutputShapeExhaustive) {
+  TensorFactory<ScalarType::Int> tf;
+  constexpr auto H = 2;
+  constexpr auto W = 3;
+  Tensor out_row = tf.zeros({1, W});
+  Tensor out_col = tf.zeros({H, 1});
+  Tensor in_0d_scalar = tf.zeros({});
+  Tensor in_1d_scalar = tf.zeros({1});
+  Tensor in_2d_scalar = tf.zeros({1, 1});
+
+  Tensor in_row = tf.zeros({W});
+  Tensor in_leading_one_row = tf.zeros({1, W});
+
+  Tensor in_col = tf.zeros({H, 1});
+
+  size_t idx = 0;
+  for (const auto
+           [out_idx,
+            in_0d_idx,
+            in_1d_idx,
+            in_2d_idx,
+            in_row_idx,
+            in_leading_one_row_idx] :
+       BroadcastIndexesRange<5>(
+           out_row,
+           in_0d_scalar,
+           in_1d_scalar,
+           in_2d_scalar,
+           in_row,
+           in_leading_one_row)) {
+    EXPECT_EQ(out_idx, idx++);
+    EXPECT_EQ(in_0d_idx, 0);
+    EXPECT_EQ(in_1d_idx, 0);
+    EXPECT_EQ(in_2d_idx, 0);
+    EXPECT_EQ(in_row_idx, out_idx);
+    EXPECT_EQ(in_leading_one_row_idx, out_idx);
+  }
+
+  idx = 0;
+  for (const auto [out_idx, in_0d_idx, in_1d_idx, in_2d_idx, in_col_idx] :
+       BroadcastIndexesRange<4>(
+           out_col, in_0d_scalar, in_1d_scalar, in_2d_scalar, in_col)) {
+    EXPECT_EQ(out_idx, idx++);
+    EXPECT_EQ(in_0d_idx, 0);
+    EXPECT_EQ(in_1d_idx, 0);
+    EXPECT_EQ(in_2d_idx, 0);
+    EXPECT_EQ(in_col_idx, out_idx);
+  }
+}
+
 // [1, 1, 1] -> [C, H, W]
 // [C, H, 1] -> [C, H, W]
 // [C, 1, W] -> [C, H, W]
@@ -166,11 +224,12 @@ TEST(BroadcastIndexesRangeTest, ThreeDBroadcasting) {
 // 4-D should generalize, but we will go ahead and test:
 // [N, 1, H, 1] -> [N, C, H, W]
 // [1, C, 1, W] -> [N, C, H, W]
-TEST(BroadcastIndexesRangeTest, FourDBroadcasting) {
+template <size_t N, size_t C, size_t H, size_t W>
+void four_d_broadcasting_test() {
   TensorFactory<ScalarType::Int> tf;
-  Tensor out = tf.zeros({2, 3, 4, 5});
-  Tensor in_broadcast_cw = tf.zeros({2, 1, 4, 1});
-  Tensor in_broadcast_nh = tf.zeros({1, 3, 1, 5});
+  Tensor out = tf.zeros({N, C, H, W});
+  Tensor in_broadcast_cw = tf.zeros({N, 1, H, 1});
+  Tensor in_broadcast_nh = tf.zeros({1, C, 1, W});
 
   // Writing out all the indexes would be too cumbersome, so here we
   // take the opportunity to mutation test against delinearize_index
@@ -190,3 +249,12 @@ TEST(BroadcastIndexesRangeTest, FourDBroadcasting) {
         linearize_access_indexes(out_indexes, out.dim(), in_broadcast_nh));
   }
 }
+
+TEST(BroadcastIndexesRangeTest, FourDBroadcasting) {
+  four_d_broadcasting_test<2, 3, 4, 5>();
+}
+
+TEST(BroadcastIndexesRangeTest, FourDBroadcastingWithOneDimsInOutput) {
+  four_d_broadcasting_test<2, 3, 1, 5>();
+  four_d_broadcasting_test<2, 1, 3, 1>();
+}
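
For context on what these tests verify: a BroadcastIndexesRange pairs each linear index into the output tensor with the corresponding linear index into each (possibly broadcast) input. The sketch below is a minimal, self-contained illustration of that mapping in plain C++; it does not use the ExecuTorch API, and broadcast_index plus the shapes in main() are illustrative assumptions, not library code. The idea is to delinearize the output index into per-dimension coordinates, force the coordinate to 0 on any dimension where the input has size 1 (or is missing), and re-linearize against the input's own sizes.

#include <cassert>
#include <cstddef>
#include <vector>

// Map a linear index into out_sizes to a linear index into in_sizes, where
// in_sizes is right-aligned against out_sizes and broadcasts on its size-1
// (or missing) dimensions. Illustrative only; not the library's implementation.
size_t broadcast_index(
    size_t out_idx,
    const std::vector<size_t>& out_sizes,
    const std::vector<size_t>& in_sizes) {
  const size_t out_dim = out_sizes.size();
  const size_t in_dim = in_sizes.size();
  size_t in_idx = 0;
  size_t in_stride = 1;
  // Walk dimensions from innermost to outermost.
  for (size_t i = 0; i < out_dim; ++i) {
    const size_t out_size = out_sizes[out_dim - 1 - i];
    const size_t out_coord = out_idx % out_size;
    out_idx /= out_size;
    if (i < in_dim) {
      const size_t in_size = in_sizes[in_dim - 1 - i];
      // A size-1 input dimension always contributes coordinate 0.
      const size_t in_coord = (in_size == 1) ? 0 : out_coord;
      in_idx += in_coord * in_stride;
      in_stride *= in_size;
    }
  }
  return in_idx;
}

int main() {
  // [W] -> [1, W]: the input index tracks the output index exactly.
  assert(broadcast_index(2, {1, 3}, {3}) == 2);
  // [H, 1] -> [H, 1]: same shape, so the indexes also match.
  assert(broadcast_index(1, {2, 1}, {2, 1}) == 1);
  // [1, 1] -> [H, W]: a scalar-like input always maps to index 0.
  assert(broadcast_index(5, {2, 3}, {1, 1}) == 0);
  return 0;
}

This mirrors the invariants asserted in the tests above: scalar-like inputs always yield index 0, and an input whose shape matches the output (up to leading 1s) yields an index equal to the output index.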