@@ -236,6 +236,87 @@ class BroadcastIndexesIterator {
236236 // shape would contain 1s.
237237 std::array<ShapeType, kNumInputs > effective_input_broadcast_strides_;
238238};
239+
240+ // When there is only 1 input and no noncontiguous tensor support
241+ // required, there is no actual broadcasting to do.
242+ template <>
243+ class BroadcastIndexesIterator <1 , false > {
244+ public:
245+ using difference_type = ssize_t ;
246+ using value_type = std::array<ssize_t , 2 >;
247+ using reference = value_type;
248+ using pointer = const value_type*;
249+ using iterator_category = std::forward_iterator_tag;
250+
251+ BroadcastIndexesIterator () = default ;
252+
253+ explicit BroadcastIndexesIterator (
254+ [[maybe_unused]] const Tensor& output,
255+ [[maybe_unused]] const Tensor& input) {}
256+
257+ struct make_end_t {
258+ explicit constexpr make_end_t () = default;
259+ };
260+
261+ BroadcastIndexesIterator (
262+ make_end_t ,
263+ const Tensor& output,
264+ [[maybe_unused]] const Tensor& input)
265+ : current_indexes_({output.numel (), output.numel ()}) {}
266+
267+ bool operator ==(const BroadcastIndexesIterator& rhs) const {
268+ return current_index () == rhs.current_index ();
269+ }
270+
271+ bool operator !=(const BroadcastIndexesIterator& rhs) const {
272+ return current_index () != rhs.current_index ();
273+ }
274+
275+ reference operator *() const {
276+ return current_indexes_;
277+ }
278+
279+ pointer operator ->() const {
280+ return ¤t_indexes_;
281+ }
282+
283+ BroadcastIndexesIterator& operator ++() {
284+ add_to_current_index (1 );
285+ return *this ;
286+ }
287+
288+ BroadcastIndexesIterator operator ++(int ) {
289+ auto it = *this ;
290+ operator ++();
291+ return it;
292+ }
293+
294+ BroadcastIndexesIterator& operator +=(difference_type n) {
295+ add_to_current_index (n);
296+ return *this ;
297+ }
298+
299+ BroadcastIndexesIterator operator +(difference_type n) {
300+ auto it = *this ;
301+ it += n;
302+ return it;
303+ }
304+
305+ difference_type operator -(const BroadcastIndexesIterator& rhs) const {
306+ return difference_type (current_index () - rhs.current_index ());
307+ }
308+
309+ private:
310+ ssize_t current_index () const {
311+ return current_indexes_[0 ];
312+ }
313+
314+ void add_to_current_index (ssize_t n) {
315+ current_indexes_[0 ] += n;
316+ current_indexes_[1 ] = current_indexes_[0 ];
317+ }
318+ value_type current_indexes_ = {{0 , 0 }};
319+ };
239320} // namespace internal
240321
241322/* *
0 commit comments