
Commit 7e13bcd

swolchok authored and BujSet committed
Specialize BroadcastIndexesRange for the case where there is only 1 contiguous input (pytorch#12023)
In this case, broadcasting is not possible if I understand correctly.

NOTE TO REVIEWERS: I deleted a failing test because I think it was testing functionality that doesn't actually exist in PyTorch. Please let me know if I've made a mistake. I tried to exercise the behavior that this test implied existed like so:

```
>>> t = torch.tensor([1, 2, 3])
>>> t2 = torch.tensor(4)
>>> torch.abs(t2, out=t)
<stdin>:1: UserWarning: An output with one or more elements was resized since it had shape [3], which does not match the required output shape []. This behavior is deprecated, and in a future PyTorch release outputs will not be resized unless they have zero elements. You can explicitly reuse an out tensor t by resizing it, inplace, to zero elements with t.resize_(0). (Triggered internally at /Users/runner/work/pytorch/pytorch/pytorch/aten/src/ATen/native/Resize.cpp:38.)
tensor(4)
```

If the test were correct, the result would have been torch.tensor([1, 2, 3]) with no warning. Also, none of our operator tests seem to be failing. Have I missed anything?
1 parent 5794fff commit 7e13bcd
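
To make the rationale concrete: with a single input and no noncontiguous-tensor support, the input must match the output shape exactly, so every (output_index, input_index) pair is simply (i, i). A minimal standalone sketch of that lockstep behavior (plain C++, independent of the ExecuTorch types; the numel value is made up for illustration):

```cpp
#include <array>
#include <cstdio>

int main() {
  // Stand-in for output.numel(); a same-shape contiguous input
  // necessarily has the same number of elements.
  const long numel = 5;
  for (long i = 0; i < numel; ++i) {
    // {output index, input index} -- always equal in this case,
    // which is exactly what the specialized iterator below encodes.
    std::array<long, 2> indexes = {i, i};
    std::printf("{%ld, %ld}\n", indexes[0], indexes[1]);
  }
  return 0;
}
```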

File tree

2 files changed: +81 −18 lines changed


kernels/portable/cpu/util/broadcast_indexes_range.h

Lines changed: 81 additions & 0 deletions
```diff
@@ -236,6 +236,87 @@ class BroadcastIndexesIterator {
   // shape would contain 1s.
   std::array<ShapeType, kNumInputs> effective_input_broadcast_strides_;
 };
+
+// When there is only 1 input and no noncontiguous tensor support
+// required, there is no actual broadcasting to do.
+template <>
+class BroadcastIndexesIterator<1, false> {
+ public:
+  using difference_type = ssize_t;
+  using value_type = std::array<ssize_t, 2>;
+  using reference = value_type;
+  using pointer = const value_type*;
+  using iterator_category = std::forward_iterator_tag;
+
+  BroadcastIndexesIterator() = default;
+
+  explicit BroadcastIndexesIterator(
+      [[maybe_unused]] const Tensor& output,
+      [[maybe_unused]] const Tensor& input) {}
+
+  struct make_end_t {
+    explicit constexpr make_end_t() = default;
+  };
+
+  BroadcastIndexesIterator(
+      make_end_t,
+      const Tensor& output,
+      [[maybe_unused]] const Tensor& input)
+      : current_indexes_({output.numel(), output.numel()}) {}
+
+  bool operator==(const BroadcastIndexesIterator& rhs) const {
+    return current_index() == rhs.current_index();
+  }
+
+  bool operator!=(const BroadcastIndexesIterator& rhs) const {
+    return current_index() != rhs.current_index();
+  }
+
+  reference operator*() const {
+    return current_indexes_;
+  }
+
+  pointer operator->() const {
+    return &current_indexes_;
+  }
+
+  BroadcastIndexesIterator& operator++() {
+    add_to_current_index(1);
+    return *this;
+  }
+
+  BroadcastIndexesIterator operator++(int) {
+    auto it = *this;
+    operator++();
+    return it;
+  }
+
+  BroadcastIndexesIterator& operator+=(difference_type n) {
+    add_to_current_index(n);
+    return *this;
+  }
+
+  BroadcastIndexesIterator operator+(difference_type n) {
+    auto it = *this;
+    it += n;
+    return it;
+  }
+
+  difference_type operator-(const BroadcastIndexesIterator& rhs) const {
+    return difference_type(current_index() - rhs.current_index());
+  }
+
+ private:
+  ssize_t current_index() const {
+    return current_indexes_[0];
+  }
+
+  void add_to_current_index(ssize_t n) {
+    current_indexes_[0] += n;
+    current_indexes_[1] = current_indexes_[0];
+  }
+
+  value_type current_indexes_ = {{0, 0}};
+};
 } // namespace internal
 
 /**
```
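
For context, here is a sketch of how an elementwise kernel might consume this range (illustrative only; the data-pointer accessors are assumed from the ExecuTorch Tensor API, and this is not the actual portable abs kernel). With the `<1, false>` specialization above, `indexes[0] == indexes[1]` on every step, so the loop degenerates into plain linear indexing with no per-element stride arithmetic:

```cpp
#include <cmath>

// Sketch: apply a unary op over (output_index, input_index) pairs.
// Assumes float dtype and the const_data_ptr/mutable_data_ptr
// accessors; error checks and dtype dispatch are omitted.
void abs_sketch(const Tensor& in, Tensor& out) {
  const float* in_data = in.const_data_ptr<float>();
  float* out_data = out.mutable_data_ptr<float>();
  for (const auto indexes : BroadcastIndexesRange<1>(out, in)) {
    out_data[indexes[0]] = std::abs(in_data[indexes[1]]);
  }
}
```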

kernels/portable/cpu/util/test/broadcast_indexes_range_test.cpp

Lines changed: 0 additions & 18 deletions
```diff
@@ -52,24 +52,6 @@ TEST(BroadcastIndexesRangeTest, OneDNotBroadcasted) {
   }
 }
 
-// [1] -> [W]
-TEST(BroadcastIndexesRangeTest, ScalarBroadcastToOneD) {
-  TensorFactory<ScalarType::Int> tf;
-
-  Tensor out = tf.zeros({5});
-  Tensor in = tf.zeros({1});
-
-  auto actual = range_to_vec(BroadcastIndexesRange<1>(out, in));
-  decltype(actual) expected = {
-      {0, 0},
-      {1, 0},
-      {2, 0},
-      {3, 0},
-      {4, 0},
-  };
-  EXPECT_EQ(expected, actual);
-}
-
 template <typename Range>
 void test_operator_plus(const Range& range) {
   size_t idx = 0;
```
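
For reference, the same-shape one-input case is the one that remains meaningful under the new specialization. A sketch of such a test, reusing the helpers visible in the deleted test above (`TensorFactory`, `range_to_vec`), would look like this: every (output, input) index pair comes back as {i, i}:

```cpp
// Sketch only: one contiguous input with the output's exact shape,
// so the specialized iterator yields lockstep index pairs.
TEST(BroadcastIndexesRangeTest, OneInputSameShapeLockstep) {
  TensorFactory<ScalarType::Int> tf;

  Tensor out = tf.zeros({3});
  Tensor in = tf.zeros({3});

  auto actual = range_to_vec(BroadcastIndexesRange<1>(out, in));
  decltype(actual) expected = {{0, 0}, {1, 1}, {2, 2}};
  EXPECT_EQ(expected, actual);
}
```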
