diff --git a/kernels/portable/cpu/util/broadcast_indexes_range.h b/kernels/portable/cpu/util/broadcast_indexes_range.h
index 7434748d505..ae9970df653 100644
--- a/kernels/portable/cpu/util/broadcast_indexes_range.h
+++ b/kernels/portable/cpu/util/broadcast_indexes_range.h
@@ -236,6 +236,102 @@ class BroadcastIndexesIterator {
   // shape would contain 1s.
   std::array effective_input_broadcast_strides_;
 };
+
+// When there is only 1 input and no noncontiguous tensor support
+// required, there is no actual broadcasting to do: the iterator simply
+// counts linearly and reports the same index in both slots of its value.
+template <>
+class BroadcastIndexesIterator<1, false> {
+ public:
+  using difference_type = ssize_t;
+  // Two slots that this specialization always keeps equal to each other.
+  using value_type = std::array<ssize_t, 2>;
+  using reference = value_type;
+  using pointer = const value_type*;
+  using iterator_category = std::forward_iterator_tag;
+
+  // Begin iterator: both indexes start at 0.
+  BroadcastIndexesIterator() = default;
+
+  // Begin iterator for the given tensors. Neither tensor is inspected
+  // because the starting position is always 0.
+  explicit BroadcastIndexesIterator(
+      [[maybe_unused]] const Tensor& output,
+      [[maybe_unused]] const Tensor& input) {}
+
+  // Tag type selecting the end-iterator constructor below.
+  struct make_end_t {
+    explicit constexpr make_end_t() = default;
+  };
+
+  // End iterator: both slots are set to output.numel(). The input tensor
+  // is not consulted.
+  BroadcastIndexesIterator(
+      make_end_t,
+      const Tensor& output,
+      [[maybe_unused]] const Tensor& input)
+      : current_indexes_({output.numel(), output.numel()}) {}
+
+  // Iterators compare by position only.
+  bool operator==(const BroadcastIndexesIterator& rhs) const {
+    return current_index() == rhs.current_index();
+  }
+
+  bool operator!=(const BroadcastIndexesIterator& rhs) const {
+    return current_index() != rhs.current_index();
+  }
+
+  // Returns the current index pair by value (reference is value_type).
+  reference operator*() const {
+    return current_indexes_;
+  }
+
+  pointer operator->() const {
+    return &current_indexes_;
+  }
+
+  BroadcastIndexesIterator& operator++() {
+    add_to_current_index(1);
+    return *this;
+  }
+
+  BroadcastIndexesIterator operator++(int) {
+    auto it = *this;
+    operator++();
+    return it;
+  }
+
+  BroadcastIndexesIterator& operator+=(difference_type n) {
+    add_to_current_index(n);
+    return *this;
+  }
+
+  // const so that it can also be applied to const iterators and
+  // temporaries bound to const references.
+  BroadcastIndexesIterator operator+(difference_type n) const {
+    auto it = *this;
+    it += n;
+    return it;
+  }
+
+  // Number of increments needed to get from rhs to *this.
+  difference_type operator-(const BroadcastIndexesIterator& rhs) const {
+    return difference_type(current_index() - rhs.current_index());
+  }
+
+ private:
+  ssize_t current_index() const {
+    return current_indexes_[0];
+  }
+
+  // Advances by n and keeps both slots in sync.
+  void add_to_current_index(ssize_t n) {
+    current_indexes_[0] += n;
+    current_indexes_[1] = current_indexes_[0];
+  }
+
+  value_type current_indexes_ = {{0, 0}};
+};
 } // namespace internal
 
 /**
diff --git a/kernels/portable/cpu/util/test/broadcast_indexes_range_test.cpp b/kernels/portable/cpu/util/test/broadcast_indexes_range_test.cpp
index 1023915ea66..42fd2484cf0 100644
--- a/kernels/portable/cpu/util/test/broadcast_indexes_range_test.cpp
+++ b/kernels/portable/cpu/util/test/broadcast_indexes_range_test.cpp
@@ -52,24 +52,6 @@ TEST(BroadcastIndexesRangeTest, OneDNotBroadcasted) {
   }
 }
 
-// [1] -> [W]
-TEST(BroadcastIndexesRangeTest, ScalarBroadcastToOneD) {
-  TensorFactory tf;
-
-  Tensor out = tf.zeros({5});
-  Tensor in = tf.zeros({1});
-
-  auto actual = range_to_vec(BroadcastIndexesRange<1>(out, in));
-  decltype(actual) expected = {
-      {0, 0},
-      {1, 0},
-      {2, 0},
-      {3, 0},
-      {4, 0},
-  };
-  EXPECT_EQ(expected, actual);
-}
-
 template
 void test_operator_plus(const Range& range) {
   size_t idx = 0;