Skip to content

Commit 39144c8

Browse files
kipawaajoeycarterdime10
authored
Issue #1621: Fix DataView iterator segfault with axis of size 0 (#2164)
**Context:** The `DataView` iterators suffer an underflow of unsigned types that can lead to a segfault when presented an axis of size 0, as discussed in #1621 and [here](#1598 (comment)). **Description of the Change:** This change removes the `var - 1` type arithmetic that lead to the underflows, which in turn prevents the segfaults, by adjusting the indexing done in the for loop. **Benefits:** No more underflow errors and segfaults, accurate `view.size()` and `std::distance(view.begin(), view.end())` results in the case of an axis of size 0. **Possible Drawbacks:** N/A **Related GitHub Issues:** Closes #1621 [sc-88502] --------- Co-authored-by: Joey Carter <[email protected]> Co-authored-by: David Ittah <[email protected]>
1 parent 7b02a09 commit 39144c8

File tree

4 files changed

+220
-21
lines changed

4 files changed

+220
-21
lines changed

doc/releases/changelog-dev.md

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,12 @@
4242
[(#1984)](https://github.com/PennyLaneAI/catalyst/pull/1984)
4343

4444
* Split `from_plxpr.py` into two files.
45-
[(#2142)](https://github.com/PennyLaneAI/catalyst/pull/2142)
45+
[(#2142)](https://github.com/PennyLaneAI/catalyst/pull/2142)
46+
47+
* Re-work `DataView` to avoid an axis of size 0 possibly triggering a segfault via an underflow
48+
error, as discovered in
49+
[this comment](https://github.com/PennyLaneAI/catalyst/pull/1598#issuecomment-2779178046).
50+
[(#1621)](https://github.com/PennyLaneAI/catalyst/pull/2164)
4651

4752
<h3>Documentation 📝</h3>
4853

@@ -58,5 +63,6 @@ This release contains contributions from (in alphabetical order):
5863

5964
Ali Asadi,
6065
Christina Lee,
66+
River McCubbin,
6167
Roberto Turrado,
6268
Paul Haochen Wang.

runtime/include/DataView.hpp

Lines changed: 20 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414

1515
#pragma once
1616

17+
#include <vector>
18+
1719
#include <Exception.hpp>
1820

1921
/**
@@ -56,39 +58,35 @@ template <typename T, size_t R> class DataView {
5658
iterator &operator++()
5759
{
5860
int64_t next_axis = -1;
59-
int64_t idx;
60-
for (int64_t i = R; i > 0; --i) {
61-
idx = i - 1;
62-
RT_ASSERT(view.sizes[idx] > 0);
63-
if (indices[idx]++ < view.sizes[idx] - 1) {
64-
next_axis = idx;
61+
for (int64_t axis = R - 1; axis >= 0; axis--) {
62+
if (++indices[axis] < view.sizes[axis]) {
63+
next_axis = axis;
6564
break;
6665
}
67-
indices[idx] = 0;
68-
loc -= (view.sizes[idx] - 1) * view.strides[idx];
66+
67+
indices[axis] = 0;
68+
69+
loc -= view.sizes[axis] == 0 ? 0 : (view.sizes[axis] - 1) * view.strides[axis];
6970
}
7071

7172
loc = next_axis == -1 ? -1 : loc + view.strides[next_axis];
7273
return *this;
7374
}
7475
iterator operator++(int)
7576
{
76-
auto tmp = *this;
77+
auto cached_iter = *this;
7778
int64_t next_axis = -1;
78-
int64_t idx;
79-
for (int64_t i = R; i > 0; --i) {
80-
idx = i - 1;
81-
RT_ASSERT(view.sizes[idx] > 0);
82-
if (indices[idx]++ < view.sizes[idx] - 1) {
83-
next_axis = idx;
79+
for (int64_t axis = R - 1; axis >= 0; axis--) {
80+
if (++indices[axis] < view.sizes[axis]) {
81+
next_axis = axis;
8482
break;
8583
}
86-
indices[idx] = 0;
87-
loc -= (view.sizes[idx] - 1) * view.strides[idx];
84+
indices[axis] = 0;
85+
loc -= view.sizes[axis] == 0 ? 0 : (view.sizes[axis] - 1) * view.strides[axis];
8886
}
8987

9088
loc = next_axis == -1 ? -1 : loc + view.strides[next_axis];
91-
return tmp;
89+
return cached_iter;
9290
}
9391
bool operator==(const iterator &other) const
9492
{
@@ -138,13 +136,15 @@ template <typename T, size_t R> class DataView {
138136

139137
size_t loc = offset;
140138
for (size_t axis = 0; axis < R; axis++) {
141-
RT_ASSERT(indices[axis] < sizes[axis]);
142139
loc += indices[axis] * strides[axis];
143140
}
144141
return data_aligned[loc];
145142
}
146143

147-
iterator begin() { return iterator{*this, static_cast<int64_t>(offset)}; }
144+
iterator begin()
145+
{
146+
return iterator{*this, (*this).size() == 0 ? -1 : static_cast<int64_t>(offset)};
147+
}
148148

149149
iterator end() { return iterator{*this, -1}; }
150150
};

runtime/tests/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@ endif()
5454

5555
add_executable(runner_tests_qir_runtime)
5656
target_sources(runner_tests_qir_runtime PRIVATE
57+
Test_DataView.cpp
5758
Test_NullQubit.cpp
5859
Test_ResourceTracker.cpp
5960
)

runtime/tests/Test_DataView.cpp

Lines changed: 192 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,192 @@
1+
// Copyright 2025 Xanadu Quantum Technologies Inc.
2+
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#include <cstdio>
16+
17+
#include <catch2/catch_approx.hpp>
18+
#include <catch2/catch_test_macros.hpp>
19+
#include <catch2/matchers/catch_matchers_string.hpp>
20+
21+
#include "DataView.hpp"
22+
23+
using namespace Catch::Matchers;
24+
25+
using namespace Catalyst::Runtime;
26+
27+
TEST_CASE("Test DataView Pre-Increment Iterator - double, 1", "[DataView]")
28+
{
29+
double data_aligned[3] = {1.0, 1.1, 1.2};
30+
size_t offset = 0U;
31+
size_t sizes[1] = {3};
32+
size_t strides[1] = {1};
33+
34+
DataView<double, 1> view(data_aligned, offset, sizes, strides);
35+
36+
auto view_iter = view.begin();
37+
38+
CHECK(*view_iter == Catch::Approx(1.0).epsilon(1e-5));
39+
CHECK(*++view_iter == Catch::Approx(1.1).epsilon(1e-5));
40+
CHECK(*++view_iter == Catch::Approx(1.2).epsilon(1e-5));
41+
}
42+
43+
TEST_CASE("Test DataView Pre-Increment Iterator - int, 2", "[DataView]")
44+
{
45+
int data_aligned[3][3] = {{0, 1, 2}, {3, 4, 5}};
46+
size_t offset = 0U;
47+
size_t sizes[2] = {3, 3};
48+
size_t strides[2] = {3, 1};
49+
50+
DataView<int, 2> view(*data_aligned, offset, sizes, strides);
51+
52+
auto view_iter = view.begin();
53+
54+
CHECK(*view_iter == 0);
55+
for (int i = 1; i < 6; i++) {
56+
CHECK(*++view_iter == i);
57+
}
58+
}
59+
60+
TEST_CASE("Test DataView Pre-Increment Iterator - int, 3", "[DataView]")
61+
{
62+
int data_aligned[18] = {0};
63+
for (int i = 0; i < 18; i++) {
64+
data_aligned[i] = i;
65+
}
66+
67+
size_t offset = 0U;
68+
size_t sizes[3] = {3, 2, 3};
69+
size_t strides[3] = {6, 3, 1};
70+
71+
DataView<int, 3> view(data_aligned, offset, sizes, strides);
72+
73+
auto view_iter = view.begin();
74+
75+
CHECK(*view_iter == 0);
76+
for (int i = 1; i < 18; i++) {
77+
CHECK(*++view_iter == i);
78+
}
79+
}
80+
81+
TEST_CASE("Test DataView Post-Increment Iterator - double, 1", "[DataView]")
82+
{
83+
double data_aligned[3] = {3.2, 4.1, 1.6};
84+
size_t offset = 0;
85+
size_t sizes[1] = {3};
86+
size_t strides[1] = {1};
87+
88+
DataView<double, 1> view(data_aligned, offset, sizes, strides);
89+
90+
auto view_iter = view.begin();
91+
92+
CHECK(*view_iter++ == Catch::Approx(3.2).epsilon(1e-5));
93+
CHECK(*view_iter++ == Catch::Approx(4.1).epsilon(1e-5));
94+
CHECK(*view_iter == Catch::Approx(1.6).epsilon(1e-5));
95+
}
96+
97+
TEST_CASE("Test DataView Post-Increment Iterator - int, 2", "[DataView]")
98+
{
99+
int data_aligned[3][3] = {{0, 1, 2}, {3, 4, 5}};
100+
size_t offset = 0U;
101+
size_t sizes[2] = {3, 3};
102+
size_t strides[2] = {3, 1};
103+
104+
DataView<int, 2> view(*data_aligned, offset, sizes, strides);
105+
106+
auto view_iter = view.begin();
107+
108+
for (int i = 0; i < 6; i++) {
109+
CHECK(*view_iter++ == i);
110+
}
111+
}
112+
113+
TEST_CASE("Test DataView Post-Increment Iterator - int, 3", "[DataView]")
114+
{
115+
int data_aligned[18] = {0};
116+
for (int i = 0; i < 18; i++) {
117+
data_aligned[i] = i;
118+
}
119+
120+
size_t offset = 0U;
121+
size_t sizes[3] = {3, 2, 3};
122+
size_t strides[3] = {6, 3, 1};
123+
124+
DataView<int, 3> view(data_aligned, offset, sizes, strides);
125+
126+
auto view_iter = view.begin();
127+
128+
for (int i = 0; i < 18; i++) {
129+
CHECK(*view_iter++ == i);
130+
}
131+
}
132+
133+
TEST_CASE("DataView Iterator Distance 0 - 0 first axis", "[DataView]")
134+
{
135+
int *data_aligned = nullptr;
136+
size_t offset = 0;
137+
size_t sizes[2] = {0, 10};
138+
size_t strides[2] = {0, 0};
139+
140+
DataView<int, 2> view(data_aligned, offset, sizes, strides);
141+
142+
CHECK(std::distance(view.begin(), view.end()) == 0);
143+
}
144+
145+
TEST_CASE("DataView Iterator Distance 0 - 0 second axis", "[DataView]")
146+
{
147+
int *data_aligned = nullptr;
148+
size_t offset = 0;
149+
size_t sizes[2] = {10, 0};
150+
size_t strides[2] = {0, 0};
151+
152+
DataView<int, 2> view(data_aligned, offset, sizes, strides);
153+
154+
CHECK(std::distance(view.begin(), view.end()) == 0);
155+
}
156+
157+
TEST_CASE("DataView Iterator Distance 4 - int, 2", "[DataView]")
158+
{
159+
int data_aligned[2][2] = {{1, 2}, {3, 4}};
160+
size_t offset = 0;
161+
size_t sizes[2] = {2, 2};
162+
size_t strides[2] = {2, 1};
163+
164+
DataView<int, 2> view(*data_aligned, offset, sizes, strides);
165+
166+
CHECK(std::distance(view.begin(), view.end()) == 4);
167+
}
168+
169+
TEST_CASE("DataView Iterator Distance 12 - double, 3", "[DataView]")
170+
{
171+
double data_aligned[2][2][3] = {{{3.1, 2.6, 9.5}, {5.4, 2.3, 8.1}},
172+
{{9.8, 8.2, 7.2}, {0.7, 9.6, 6.6}}};
173+
size_t offset = 0;
174+
size_t sizes[3] = {2, 2, 3};
175+
size_t strides[3] = {6, 3, 1};
176+
177+
DataView<double, 3> view(**data_aligned, offset, sizes, strides);
178+
179+
CHECK(std::distance(view.begin(), view.end()) == 12);
180+
}
181+
182+
TEST_CASE("DataView Size - 0 first axis", "[DataView]")
183+
{
184+
int *data_aligned = nullptr;
185+
size_t offset = 0;
186+
size_t sizes[2] = {0, 10};
187+
size_t strides[2] = {0, 0};
188+
189+
DataView<int, 2> view(data_aligned, offset, sizes, strides);
190+
191+
CHECK(view.size() == 0);
192+
}

0 commit comments

Comments
 (0)