Skip to content

Commit 95e9513

Browse files
committed
Fixes for Arrow STL iterator for custom types
1 parent 8974ddc commit 95e9513

File tree

2 files changed

+345
-40
lines changed

2 files changed

+345
-40
lines changed

cpp/src/arrow/stl_iterator.h

Lines changed: 112 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
#include "arrow/type.h"
2828
#include "arrow/type_fwd.h"
2929
#include "arrow/type_traits.h"
30+
#include "arrow/util/functional.h"
3031
#include "arrow/util/macros.h"
3132

3233
namespace arrow {
@@ -38,11 +39,32 @@ template <typename ArrayType>
3839
struct DefaultValueAccessor {
3940
using ValueType = decltype(std::declval<ArrayType>().GetView(0));
4041

41-
ValueType operator()(const ArrayType& array, int64_t index) {
42+
ValueType operator()(const ArrayType& array, int64_t index) const {
4243
return array.GetView(index);
4344
}
4445
};
4546

47+
// Trait reporting whether `T` declares a nested type named `ValueType`.
//
// The primary template derives from std::false_type. The partial
// specialization is viable only when `typename T::ValueType` is well-formed,
// in which case it is selected and derives from std::true_type.
template <typename T, typename = void>
struct has_value_type : std::false_type {};

template <typename T>
struct has_value_type<T, std::void_t<typename T::ValueType>> : std::true_type {};
53+
54+
// Wrapper for callable objects (like lambdas) that don't have ValueType.
//
// Gives such a callable the accessor shape used elsewhere in this header: a
// type exposing `ValueType` and a const `operator()(const ArrayType&, int64_t)`.
//
// Template parameters:
//   Callable  - type of the wrapped callable; stored by value
//   Ret       - type published as `ValueType` and returned by operator()
//   ArrayType - type of the first argument passed through to the callable
//
// Limitations that follow from the definition below:
//   * There is no default constructor (the only constructor takes a Callable).
//   * The wrapper is copy-assignable only if `Callable` is; closure types in
//     C++17 have a deleted copy-assignment operator.
//   * The callable is invoked through a const member function, so it must be
//     callable on a const object (a lambda declared `mutable` is not).
template <typename Callable, typename Ret, typename ArrayType>
struct CallableValueAccessor {
  using ValueType = Ret;

  Callable callable;

  // Takes the callable by value and moves it into the member.
  explicit CallableValueAccessor(Callable c) : callable(std::move(c)) {}

  // Forwards (array, index) to the wrapped callable and returns its result
  // converted to ValueType.
  ValueType operator()(const ArrayType& array, int64_t index) const {
    return callable(array, index);
  }
};
67+
4668
} // namespace detail
4769

4870
template <typename ArrayType,
@@ -56,20 +78,22 @@ class ArrayIterator {
5678
using iterator_category = std::random_access_iterator_tag;
5779

5880
// Some algorithms need to default-construct an iterator
59-
ArrayIterator() : array_(NULLPTR), index_(0) {}
81+
ArrayIterator() : array_(NULLPTR), index_(0), value_accessor_() {}
6082

61-
explicit ArrayIterator(const ArrayType& array, int64_t index = 0)
62-
: array_(&array), index_(index) {}
83+
explicit ArrayIterator(const ArrayType& array, int64_t index = 0,
84+
ValueAccessor value_accessor = {})
85+
: array_(&array), index_(index), value_accessor_(std::move(value_accessor)) {}
6386

6487
// Value access
6588
value_type operator*() const {
6689
assert(array_);
67-
return array_->IsNull(index_) ? value_type{} : array_->GetView(index_);
90+
return array_->IsNull(index_) ? value_type{} : value_accessor_(*array_, index_);
6891
}
6992

7093
value_type operator[](difference_type n) const {
7194
assert(array_);
72-
return array_->IsNull(index_ + n) ? value_type{} : array_->GetView(index_ + n);
95+
return array_->IsNull(index_ + n) ? value_type{}
96+
: value_accessor_(*array_, index_ + n);
7397
}
7498

7599
int64_t index() const { return index_; }
@@ -99,18 +123,18 @@ class ArrayIterator {
99123
return index_ - other.index_;
100124
}
101125
ArrayIterator operator+(difference_type n) const {
102-
return ArrayIterator(*array_, index_ + n);
126+
return ArrayIterator(*array_, index_ + n, value_accessor_);
103127
}
104128
ArrayIterator operator-(difference_type n) const {
105-
return ArrayIterator(*array_, index_ - n);
129+
return ArrayIterator(*array_, index_ - n, value_accessor_);
106130
}
107131
friend inline ArrayIterator operator+(difference_type diff,
108132
const ArrayIterator& other) {
109-
return ArrayIterator(*other.array_, diff + other.index_);
133+
return ArrayIterator(*other.array_, diff + other.index_, other.value_accessor_);
110134
}
111135
friend inline ArrayIterator operator-(difference_type diff,
112136
const ArrayIterator& other) {
113-
return ArrayIterator(*other.array_, diff - other.index_);
137+
return ArrayIterator(*other.array_, diff - other.index_, other.value_accessor_);
114138
}
115139
ArrayIterator& operator+=(difference_type n) {
116140
index_ += n;
@@ -132,6 +156,7 @@ class ArrayIterator {
132156
private:
133157
const ArrayType* array_;
134158
int64_t index_;
159+
ValueAccessor value_accessor_;
135160
};
136161

137162
template <typename ArrayType,
@@ -145,18 +170,22 @@ class ChunkedArrayIterator {
145170
using iterator_category = std::random_access_iterator_tag;
146171

147172
// Some algorithms need to default-construct an iterator
148-
ChunkedArrayIterator() noexcept : chunked_array_(NULLPTR), index_(0) {}
173+
ChunkedArrayIterator() noexcept
174+
: chunked_array_(NULLPTR), index_(0), value_accessor_() {}
149175

150-
explicit ChunkedArrayIterator(const ChunkedArray& chunked_array,
151-
int64_t index = 0) noexcept
152-
: chunked_array_(&chunked_array), index_(index) {}
176+
explicit ChunkedArrayIterator(const ChunkedArray& chunked_array, int64_t index = 0,
177+
ValueAccessor value_accessor = {}) noexcept
178+
: chunked_array_(&chunked_array),
179+
index_(index),
180+
value_accessor_(std::move(value_accessor)) {}
153181

154182
// Value access
155183
value_type operator*() const {
156184
auto chunk_location = GetChunkLocation(index_);
157-
ArrayIterator<ArrayType> target_iterator{
185+
ArrayIterator<ArrayType, ValueAccessor> target_iterator{
158186
arrow::internal::checked_cast<const ArrayType&>(
159-
*chunked_array_->chunk(static_cast<int>(chunk_location.chunk_index)))};
187+
*chunked_array_->chunk(static_cast<int>(chunk_location.chunk_index))),
188+
0, value_accessor_};
160189
return target_iterator[chunk_location.index_in_chunk];
161190
}
162191

@@ -191,21 +220,23 @@ class ChunkedArrayIterator {
191220
}
192221
ChunkedArrayIterator operator+(difference_type n) const {
193222
assert(chunked_array_);
194-
return ChunkedArrayIterator(*chunked_array_, index_ + n);
223+
return ChunkedArrayIterator(*chunked_array_, index_ + n, value_accessor_);
195224
}
196225
ChunkedArrayIterator operator-(difference_type n) const {
197226
assert(chunked_array_);
198-
return ChunkedArrayIterator(*chunked_array_, index_ - n);
227+
return ChunkedArrayIterator(*chunked_array_, index_ - n, value_accessor_);
199228
}
200229
friend inline ChunkedArrayIterator operator+(difference_type diff,
201230
const ChunkedArrayIterator& other) {
202231
assert(other.chunked_array_);
203-
return ChunkedArrayIterator(*other.chunked_array_, diff + other.index_);
232+
return ChunkedArrayIterator(*other.chunked_array_, diff + other.index_,
233+
other.value_accessor_);
204234
}
205235
friend inline ChunkedArrayIterator operator-(difference_type diff,
206236
const ChunkedArrayIterator& other) {
207237
assert(other.chunked_array_);
208-
return ChunkedArrayIterator(*other.chunked_array_, diff - other.index_);
238+
return ChunkedArrayIterator(*other.chunked_array_, diff - other.index_,
239+
other.value_accessor_);
209240
}
210241
ChunkedArrayIterator& operator+=(difference_type n) {
211242
index_ += n;
@@ -244,56 +275,97 @@ class ChunkedArrayIterator {
244275

245276
const ChunkedArray* chunked_array_;
246277
int64_t index_;
278+
ValueAccessor value_accessor_;
247279
};
248280

249281
/// Return an iterator to the beginning of the chunked array
///
/// The added version also accepts a `ValueAccessor` (defaulting to
/// `detail::DefaultValueAccessor<ArrayType>`, value-initialized) and passes it,
/// together with index 0, to the `ChunkedArrayIterator` constructor.
250-
template <typename Type, typename ArrayType = typename TypeTraits<Type>::ArrayType>
251-
ChunkedArrayIterator<ArrayType> Begin(const ChunkedArray& chunked_array) {
252-
return ChunkedArrayIterator<ArrayType>(chunked_array);
282+
template <typename Type, typename ArrayType = typename TypeTraits<Type>::ArrayType,
283+
typename ValueAccessor = detail::DefaultValueAccessor<ArrayType>>
284+
ChunkedArrayIterator<ArrayType, ValueAccessor> Begin(const ChunkedArray& chunked_array,
285+
ValueAccessor value_accessor = {}) {
286+
return ChunkedArrayIterator<ArrayType, ValueAccessor>(chunked_array, 0,
287+
std::move(value_accessor));
253288
}
254289

255290
/// Return an iterator to the end of the chunked array
///
/// The iterator is positioned at index `chunked_array.length()`. The added
/// version also accepts a `ValueAccessor` (defaulting to
/// `detail::DefaultValueAccessor<ArrayType>`, value-initialized) and moves it
/// into the returned iterator.
256-
template <typename Type, typename ArrayType = typename TypeTraits<Type>::ArrayType>
257-
ChunkedArrayIterator<ArrayType> End(const ChunkedArray& chunked_array) {
258-
return ChunkedArrayIterator<ArrayType>(chunked_array, chunked_array.length());
291+
template <typename Type, typename ArrayType = typename TypeTraits<Type>::ArrayType,
292+
typename ValueAccessor = detail::DefaultValueAccessor<ArrayType>>
293+
ChunkedArrayIterator<ArrayType, ValueAccessor> End(const ChunkedArray& chunked_array,
294+
ValueAccessor value_accessor = {}) {
295+
return ChunkedArrayIterator<ArrayType, ValueAccessor>(
296+
chunked_array, chunked_array.length(), std::move(value_accessor));
259297
}
260298

/// Range over a ChunkedArray, usable with range-based for via begin()/end().
///
/// Holds a raw pointer to the chunked array (the range does not manage its
/// lifetime) and, in the added version, a `ValueAccessor` that is copied into
/// every iterator it hands out. begin() yields an iterator at index 0 and
/// end() one at `chunked_array->length()`. Neither member function is
/// const-qualified, so they cannot be called on a const range object.
261-
template <typename ArrayType>
299+
template <typename ArrayType,
300+
typename ValueAccessor = detail::DefaultValueAccessor<ArrayType>>
262301
struct ChunkedArrayRange {
263302
const ChunkedArray* chunked_array;
303+
ValueAccessor value_accessor;
264304

265-
ChunkedArrayIterator<ArrayType> begin() {
266-
return stl::ChunkedArrayIterator<ArrayType>(*chunked_array);
305+
ChunkedArrayIterator<ArrayType, ValueAccessor> begin() {
306+
return stl::ChunkedArrayIterator<ArrayType, ValueAccessor>(*chunked_array, 0,
307+
value_accessor);
267308
}
268-
ChunkedArrayIterator<ArrayType> end() {
269-
return stl::ChunkedArrayIterator<ArrayType>(*chunked_array, chunked_array->length());
309+
ChunkedArrayIterator<ArrayType, ValueAccessor> end() {
310+
return stl::ChunkedArrayIterator<ArrayType, ValueAccessor>(
311+
*chunked_array, chunked_array->length(), value_accessor);
270312
}
271313
};
272314

273315
/// Return an iterable range over the chunked array
///
/// `Type` must be given explicitly; `ArrayType` defaults to
/// `TypeTraits<Type>::ArrayType`. The added version also accepts a
/// `ValueAccessor` (defaulting to `detail::DefaultValueAccessor<ArrayType>`)
/// and moves it into the returned range.
///
/// The returned range stores the address of `chunked_array`, so the chunked
/// array must remain alive for as long as the range is used.
274-
template <typename Type, typename ArrayType = typename TypeTraits<Type>::ArrayType>
275-
ChunkedArrayRange<ArrayType> Iterate(const ChunkedArray& chunked_array) {
276-
return stl::ChunkedArrayRange<ArrayType>{&chunked_array};
316+
template <typename Type, typename ArrayType = typename TypeTraits<Type>::ArrayType,
317+
typename ValueAccessor = detail::DefaultValueAccessor<ArrayType>>
318+
ChunkedArrayRange<ArrayType, ValueAccessor> Iterate(const ChunkedArray& chunked_array,
319+
ValueAccessor value_accessor = {}) {
320+
return stl::ChunkedArrayRange<ArrayType, ValueAccessor>{&chunked_array,
321+
std::move(value_accessor)};
322+
}
323+
324+
/// Return an iterable range over the chunked array with a custom value accessor.
325+
/// This overload takes ArrayType from the accessor's first parameter type, as
/// reported by internal::call_traits::argument_type<0, ValueAccessor> and then
/// decayed, so no explicit template arguments are needed at the call site.
326+
/// It participates in overload resolution only when ValueAccessor declares a
/// nested ValueType (detail::has_value_type) and passes
/// internal::call_traits::disable_if_overloaded. A callable without a nested
/// ValueType, such as a lambda, is not accepted here.
///
/// The returned range stores the address of `chunked_array`, so the chunked
/// array must remain alive for as long as the range is used.
327+
template <typename ValueAccessor,
328+
typename = internal::call_traits::disable_if_overloaded<ValueAccessor>,
329+
typename = std::enable_if_t<detail::has_value_type<ValueAccessor>::value>>
330+
auto Iterate(const ChunkedArray& chunked_array, ValueAccessor value_accessor) {
331+
using ArrayType = std::decay_t<internal::call_traits::argument_type<0, ValueAccessor>>;
332+
return stl::ChunkedArrayRange<ArrayType, ValueAccessor>{&chunked_array,
333+
std::move(value_accessor)};
334+
}
335+
336+
/// Return an iterable range over the chunked array with a callable (e.g., lambda)
337+
/// This overload wraps callables that don't have a ValueType typedef
///
/// ArrayType is the decayed first parameter type of the callable and the
/// wrapper's ValueType is the callable's decayed return type, both obtained
/// through internal::call_traits. The callable is moved into a
/// detail::CallableValueAccessor, which becomes the range's value accessor.
///
/// The trailing `typename = void` template parameter has no use in the body;
/// it makes this template's parameter list differ from that of the
/// ValueType-based Iterate overload, whose function signature is otherwise
/// the same.
///
/// The returned range stores the address of `chunked_array`, so the chunked
/// array must remain alive for as long as the range is used.
338+
template <typename Callable,
339+
typename = internal::call_traits::disable_if_overloaded<Callable>,
340+
typename = std::enable_if_t<!detail::has_value_type<Callable>::value>,
341+
typename = void>
342+
auto Iterate(const ChunkedArray& chunked_array, Callable callable) {
343+
using ArrayType = std::decay_t<internal::call_traits::argument_type<0, Callable>>;
344+
using ReturnType = std::decay_t<internal::call_traits::return_type<Callable>>;
345+
using WrappedAccessor = detail::CallableValueAccessor<Callable, ReturnType, ArrayType>;
346+
347+
return stl::ChunkedArrayRange<ArrayType, WrappedAccessor>{
348+
&chunked_array, WrappedAccessor{std::move(callable)}};
277349
}
278350

279351
} // namespace stl
280352
} // namespace arrow
281353

282354
namespace std {
283355

// std::iterator_traits specialization for arrow::stl::ArrayIterator. Each of
// the five member types is forwarded from the iterator class itself. The
// added version is partial over both ArrayType and ValueAccessor, so it also
// matches iterators instantiated with a non-default accessor.
284-
template <typename ArrayType>
285-
struct iterator_traits<::arrow::stl::ArrayIterator<ArrayType>> {
286-
using IteratorType = ::arrow::stl::ArrayIterator<ArrayType>;
356+
template <typename ArrayType, typename ValueAccessor>
357+
struct iterator_traits<::arrow::stl::ArrayIterator<ArrayType, ValueAccessor>> {
358+
using IteratorType = ::arrow::stl::ArrayIterator<ArrayType, ValueAccessor>;
287359
using difference_type = typename IteratorType::difference_type;
288360
using value_type = typename IteratorType::value_type;
289361
using pointer = typename IteratorType::pointer;
290362
using reference = typename IteratorType::reference;
291363
using iterator_category = typename IteratorType::iterator_category;
292364
};
293365

294-
template <typename ArrayType>
295-
struct iterator_traits<::arrow::stl::ChunkedArrayIterator<ArrayType>> {
296-
using IteratorType = ::arrow::stl::ChunkedArrayIterator<ArrayType>;
366+
template <typename ArrayType, typename ValueAccessor>
367+
struct iterator_traits<::arrow::stl::ChunkedArrayIterator<ArrayType, ValueAccessor>> {
368+
using IteratorType = ::arrow::stl::ChunkedArrayIterator<ArrayType, ValueAccessor>;
297369
using difference_type = typename IteratorType::difference_type;
298370
using value_type = typename IteratorType::value_type;
299371
using pointer = typename IteratorType::pointer;

0 commit comments

Comments
 (0)