Skip to content

Commit 5b88cf4

Browse files
committed
Code cleanup based on code review.
1 parent aa3f7b6 commit 5b88cf4

File tree

3 files changed

+60
-26
lines changed

3 files changed

+60
-26
lines changed

src/pdal/PyArray.cpp

Lines changed: 31 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -172,56 +172,64 @@ std::shared_ptr<ArrayIter> Array::iterator()
172172
ArrayIter::ArrayIter(PyArrayObject* np_array, std::shared_ptr<ArrayStreamHandler> stream_handler)
173173
: m_stream_handler(std::move(stream_handler))
174174
{
175-
resetIterator(np_array);
175+
// Create iterator
176+
m_iter = NpyIter_New(np_array,
177+
NPY_ITER_EXTERNAL_LOOP | NPY_ITER_READONLY | NPY_ITER_REFS_OK,
178+
NPY_KEEPORDER, NPY_NO_CASTING, NULL);
179+
if (!m_iter)
180+
throw pdal_error("Unable to create numpy iterator.");
181+
182+
initIterator();
176183
}
177184

178-
void ArrayIter::resetIterator(std::optional<PyArrayObject*> np_array = {})
185+
void ArrayIter::initIterator()
179186
{
180-
std::optional<int> stream_chunk_size = std::nullopt;
187+
// For a stream handler, first execute it to get the buffer populated and know the size of the data to iterate
188+
int64_t stream_chunk_size = 0;
181189
if (m_stream_handler) {
182190
stream_chunk_size = (*m_stream_handler)();
183-
if (*stream_chunk_size == 0) {
191+
if (!stream_chunk_size) {
184192
m_done = true;
185193
return;
186194
}
187195
}
188196

189-
if (np_array) {
190-
// Init iterator
191-
m_iter = NpyIter_New(np_array.value(),
192-
NPY_ITER_EXTERNAL_LOOP | NPY_ITER_READONLY | NPY_ITER_REFS_OK,
193-
NPY_KEEPORDER, NPY_NO_CASTING, NULL);
194-
if (!m_iter)
195-
throw pdal_error("Unable to create numpy iterator.");
196-
} else {
197-
// Otherwise, reset the iterator to the initial state
198-
if (NpyIter_Reset(m_iter, NULL) != NPY_SUCCEED) {
199-
NpyIter_Deallocate(m_iter);
200-
throw pdal_error("Unable to reset numpy iterator.");
201-
}
202-
}
203-
197+
// Initialize the iterator function
204198
char *itererr;
205199
m_iterNext = NpyIter_GetIterNext(m_iter, &itererr);
206200
if (!m_iterNext)
207201
{
208202
NpyIter_Deallocate(m_iter);
209-
throw pdal_error(std::string("Unable to create numpy iterator: ") + itererr);
203+
m_iter = nullptr;
204+
throw pdal_error(std::string("Unable to retrieve iteration function from numpy iterator: ") + itererr);
210205
}
211206
m_data = NpyIter_GetDataPtrArray(m_iter);
212207
m_stride = *NpyIter_GetInnerStrideArray(m_iter);
213208
m_size = *NpyIter_GetInnerLoopSizePtr(m_iter);
214209
if (stream_chunk_size) {
215-
if (0 <= *stream_chunk_size && *stream_chunk_size <= m_size) {
216-
m_size = *stream_chunk_size;
210+
// Ensure chunk size is valid and then limit iteration accordingly
211+
if (0 < stream_chunk_size && stream_chunk_size <= m_size) {
212+
m_size = stream_chunk_size;
217213
} else {
218214
throw pdal_error(std::string("Stream chunk size not in the range of array length: ") +
219-
std::to_string(*stream_chunk_size));
215+
std::to_string(stream_chunk_size));
220216
}
221217
}
222218
m_done = false;
223219
}
224220

221+
void ArrayIter::resetIterator()
222+
{
223+
// Reset the iterator to the initial state
224+
if (NpyIter_Reset(m_iter, NULL) != NPY_SUCCEED) {
225+
NpyIter_Deallocate(m_iter);
226+
m_iter = nullptr;
227+
throw pdal_error("Unable to reset numpy iterator.");
228+
}
229+
230+
initIterator();
231+
}
232+
225233
ArrayIter::~ArrayIter()
226234
{
227235
if (m_iter != nullptr) {

src/pdal/PyArray.hpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,6 @@
4949

5050
#include <vector>
5151
#include <memory>
52-
#include <optional>
5352

5453
namespace pdal
5554
{
@@ -112,7 +111,8 @@ class PDAL_DLL ArrayIter
112111
bool m_done;
113112

114113
std::shared_ptr<ArrayStreamHandler> m_stream_handler;
115-
void resetIterator(std::optional<PyArrayObject*> np_array);
114+
void initIterator();
115+
void resetIterator();
116116
};
117117

118118
} // namespace python

test/test_pipeline.py

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -849,4 +849,30 @@ def test_pipeline_run_backward_compat(self, in_arrays, use_setter: bool):
849849
assert len(out_arrays) == len(in_arrays)
850850

851851
for in_array, out_array in zip(in_arrays, out_arrays):
852-
np.testing.assert_array_equal(out_array, in_array)
852+
np.testing.assert_array_equal(out_array, in_array)
853+
854+
@pytest.mark.parametrize("in_array, invalid_chunk_size", [
855+
(in_array, invalid_chunk_size) for in_array, invalid_chunk_size in product(
856+
[gen_chunk(1234)],
857+
[-1, 12345])
858+
])
859+
def test_pipeline_fail_with_invalid_chunk_size(self, in_array, invalid_chunk_size):
860+
"""
861+
Ensure execution fails when using an invalid stream handler:
862+
- One that returns a negative chunk size
863+
- One that returns a chunk size bigger than the buffer capacity
864+
"""
865+
was_called = False
866+
def invalid_stream_handler():
867+
nonlocal was_called
868+
if was_called:
869+
# avoid infinite loop
870+
raise ValueError("Invalid handler should not have been called a second time")
871+
was_called = True
872+
return invalid_chunk_size
873+
874+
p = pdal.Pipeline(arrays=[in_array], stream_handlers=[invalid_stream_handler])
875+
with pytest.raises(RuntimeError,
876+
match=f"Stream chunk size not in the range of array length: {invalid_chunk_size}"):
877+
p.execute()
878+

0 commit comments

Comments
 (0)