@@ -236,6 +236,87 @@ class BroadcastIndexesIterator {
236
236
// shape would contain 1s.
237
237
std::array<ShapeType, kNumInputs > effective_input_broadcast_strides_;
238
238
};
239
+
240
+ // When there is only 1 input and no noncontiguous tensor support
241
+ // required, there is no actual broadcasting to do.
242
+ template <>
243
+ class BroadcastIndexesIterator <1 , false > {
244
+ public:
245
+ using difference_type = ssize_t ;
246
+ using value_type = std::array<ssize_t , 2 >;
247
+ using reference = value_type;
248
+ using pointer = const value_type*;
249
+ using iterator_category = std::forward_iterator_tag;
250
+
251
+ BroadcastIndexesIterator () = default ;
252
+
253
+ explicit BroadcastIndexesIterator (
254
+ [[maybe_unused]] const Tensor& output,
255
+ [[maybe_unused]] const Tensor& input) {}
256
+
257
+ struct make_end_t {
258
+ explicit constexpr make_end_t () = default;
259
+ };
260
+
261
+ BroadcastIndexesIterator (
262
+ make_end_t ,
263
+ const Tensor& output,
264
+ [[maybe_unused]] const Tensor& input)
265
+ : current_indexes_({output.numel (), output.numel ()}) {}
266
+
267
+ bool operator ==(const BroadcastIndexesIterator& rhs) const {
268
+ return current_index () == rhs.current_index ();
269
+ }
270
+
271
+ bool operator !=(const BroadcastIndexesIterator& rhs) const {
272
+ return current_index () != rhs.current_index ();
273
+ }
274
+
275
+ reference operator *() const {
276
+ return current_indexes_;
277
+ }
278
+
279
+ pointer operator ->() const {
280
+ return ¤t_indexes_;
281
+ }
282
+
283
+ BroadcastIndexesIterator& operator ++() {
284
+ add_to_current_index (1 );
285
+ return *this ;
286
+ }
287
+
288
+ BroadcastIndexesIterator operator ++(int ) {
289
+ auto it = *this ;
290
+ operator ++();
291
+ return it;
292
+ }
293
+
294
+ BroadcastIndexesIterator& operator +=(difference_type n) {
295
+ add_to_current_index (n);
296
+ return *this ;
297
+ }
298
+
299
+ BroadcastIndexesIterator operator +(difference_type n) {
300
+ auto it = *this ;
301
+ it += n;
302
+ return it;
303
+ }
304
+
305
+ difference_type operator -(const BroadcastIndexesIterator& rhs) const {
306
+ return difference_type (current_index () - rhs.current_index ());
307
+ }
308
+
309
+ private:
310
+ ssize_t current_index () const {
311
+ return current_indexes_[0 ];
312
+ }
313
+
314
+ void add_to_current_index (ssize_t n) {
315
+ current_indexes_[0 ] += n;
316
+ current_indexes_[1 ] = current_indexes_[0 ];
317
+ }
318
+ value_type current_indexes_ = {{0 , 0 }};
319
+ };
239
320
} // namespace internal
240
321
241
322
/* *
0 commit comments