Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
81 changes: 81 additions & 0 deletions kernels/portable/cpu/util/broadcast_indexes_range.h
Original file line number Diff line number Diff line change
Expand Up @@ -236,6 +236,87 @@ class BroadcastIndexesIterator {
// shape would contain 1s.
std::array<ShapeType, kNumInputs> effective_input_broadcast_strides_;
};

// When there is only 1 input and no noncontiguous tensor support
// required, there is no actual broadcasting to do.
template <>
class BroadcastIndexesIterator<1, false> {
public:
using difference_type = ssize_t;
using value_type = std::array<ssize_t, 2>;
using reference = value_type;
using pointer = const value_type*;
using iterator_category = std::forward_iterator_tag;

BroadcastIndexesIterator() = default;

explicit BroadcastIndexesIterator(
[[maybe_unused]] const Tensor& output,
[[maybe_unused]] const Tensor& input) {}

struct make_end_t {
explicit constexpr make_end_t() = default;
};

BroadcastIndexesIterator(
make_end_t,
const Tensor& output,
[[maybe_unused]] const Tensor& input)
: current_indexes_({output.numel(), output.numel()}) {}

bool operator==(const BroadcastIndexesIterator& rhs) const {
return current_index() == rhs.current_index();
}

bool operator!=(const BroadcastIndexesIterator& rhs) const {
return current_index() != rhs.current_index();
}

reference operator*() const {
return current_indexes_;
}

pointer operator->() const {
return &current_indexes_;
}

BroadcastIndexesIterator& operator++() {
add_to_current_index(1);
return *this;
}

BroadcastIndexesIterator operator++(int) {
auto it = *this;
operator++();
return it;
}

BroadcastIndexesIterator& operator+=(difference_type n) {
add_to_current_index(n);
return *this;
}

BroadcastIndexesIterator operator+(difference_type n) {
auto it = *this;
it += n;
return it;
}

difference_type operator-(const BroadcastIndexesIterator& rhs) const {
return difference_type(current_index() - rhs.current_index());
}

private:
ssize_t current_index() const {
return current_indexes_[0];
}

void add_to_current_index(ssize_t n) {
current_indexes_[0] += n;
current_indexes_[1] = current_indexes_[0];
}
value_type current_indexes_ = {{0, 0}};
};
} // namespace internal

/**
Expand Down
18 changes: 0 additions & 18 deletions kernels/portable/cpu/util/test/broadcast_indexes_range_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -52,24 +52,6 @@ TEST(BroadcastIndexesRangeTest, OneDNotBroadcasted) {
}
}

// [1] -> [W]
TEST(BroadcastIndexesRangeTest, ScalarBroadcastToOneD) {
TensorFactory<ScalarType::Int> tf;

Tensor out = tf.zeros({5});
Tensor in = tf.zeros({1});

auto actual = range_to_vec(BroadcastIndexesRange<1>(out, in));
decltype(actual) expected = {
{0, 0},
{1, 0},
{2, 0},
{3, 0},
{4, 0},
};
EXPECT_EQ(expected, actual);
}

template <typename Range>
void test_operator_plus(const Range& range) {
size_t idx = 0;
Expand Down
Loading