Skip to content

Commit 2e98940

Browse files
authored
Support nested structs in cuda.compute (#6353)
* Enable nesting of gpu_structs
* Update existing usage of gpu_struct to the new API
* Enable nesting of ZipIterators
* Add tests for nested structs and ZipIterators
* Reorganize struct examples and add new examples for nested structs
* Address review comments

---------

Co-authored-by: Ashwin Srinath <shwina@users.noreply.github.com>
1 parent 2f82c4d commit 2e98940

17 files changed

+1328
-235
lines changed

python/cuda_cccl/cuda/compute/_cccl_interop.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -313,7 +313,7 @@ def to_cccl_op(op: Callable | OpKind, sig: Signature | None) -> Op:
313313

314314

315315
def get_value_type(d_in: IteratorBase | DeviceArrayLike):
316-
from .struct import gpu_struct_from_numpy_dtype
316+
from .struct import gpu_struct
317317

318318
if isinstance(d_in, IteratorBase):
319319
return d_in.value_type
@@ -323,7 +323,7 @@ def get_value_type(d_in: IteratorBase | DeviceArrayLike):
323323
# types directly, as those are passed by pointer to device
324324
# functions. Instead, we create an anonymous struct type
325325
# which has the appropriate pass-by-value semantics.
326-
return as_numba_type(gpu_struct_from_numpy_dtype("anonymous", dtype))
326+
return as_numba_type(gpu_struct(dtype))
327327
return numba.from_dtype(dtype)
328328

329329

python/cuda_cccl/cuda/compute/algorithms/_scan.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -71,14 +71,18 @@ def __init__(
7171
init_value_type_info = self.d_in_cccl.value_type
7272

7373
case _bindings.InitKind.FUTURE_VALUE_INIT:
74-
self.init_value_cccl = cccl.to_cccl_input_iter(init_value)
74+
self.init_value_cccl = cccl.to_cccl_input_iter(
75+
cast(DeviceArrayLike, init_value)
76+
)
7577
value_type = numba.from_dtype(
7678
protocols.get_dtype(cast(DeviceArrayLike, init_value))
7779
)
7880
init_value_type_info = self.init_value_cccl.value_type
7981

8082
case _bindings.InitKind.VALUE_INIT:
81-
self.init_value_cccl = cccl.to_cccl_value(init_value)
83+
self.init_value_cccl = cccl.to_cccl_value(
84+
cast(np.ndarray | GpuStruct, init_value)
85+
)
8286
value_type = (
8387
numba.from_dtype(init_value.dtype)
8488
if isinstance(init_value, np.ndarray)
@@ -141,7 +145,9 @@ def __call__(
141145

142146
case _bindings.InitKind.VALUE_INIT:
143147
self.init_value_cccl = cast(_bindings.Value, self.init_value_cccl)
144-
self.init_value_cccl.state = to_cccl_value_state(init_value)
148+
self.init_value_cccl.state = to_cccl_value_state(
149+
cast(np.ndarray | GpuStruct, init_value)
150+
)
145151

146152
stream_handle = validate_and_get_stream(stream)
147153

python/cuda_cccl/cuda/compute/algorithms/_transform.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -273,7 +273,7 @@ def unary_transform(
273273
When working with custom struct types, you need to provide type annotations
274274
to help with type inference. See the binary transform struct example for reference:
275275
276-
.. literalinclude:: ../../python/cuda_cccl/tests/compute/examples/transform/binary_transform_struct.py
276+
.. literalinclude:: ../../python/cuda_cccl/tests/compute/examples/struct/struct_transform.py
277277
:language: python
278278
:start-after: # example-begin
279279
@@ -312,7 +312,7 @@ def binary_transform(
312312
When working with custom struct types, you need to provide type annotations
313313
to help with type inference. See the following example:
314314
315-
.. literalinclude:: ../../python/cuda_cccl/tests/compute/examples/transform/binary_transform_struct.py
315+
.. literalinclude:: ../../python/cuda_cccl/tests/compute/examples/struct/struct_transform.py
316316
:language: python
317317
:start-after: # example-begin
318318

python/cuda_cccl/cuda/compute/iterators/_factories.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -232,6 +232,12 @@ def ZipIterator(*iterators):
232232
:language: python
233233
:start-after: # example-begin
234234
235+
ZipIterator can also be used with nested gpu_struct types:
236+
237+
.. literalinclude:: ../../python/cuda_cccl/tests/compute/examples/struct/nested_struct_zip_iterator.py
238+
:language: python
239+
:start-after: # example-begin
240+
235241
Args:
236242
*iterators: Variable number of iterators to zip (at least 1)
237243

python/cuda_cccl/cuda/compute/iterators/_iterators.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
import numpy as np
99
from llvmlite import ir
1010
from numba import cuda, types
11-
from numba.core.extending import intrinsic, overload
11+
from numba.core.extending import as_numba_type, intrinsic, overload
1212
from numba.core.typing.ctypes_utils import to_ctypes
1313
from numba.cuda.dispatcher import CUDADispatcher
1414

@@ -680,6 +680,8 @@ def make_permutation_iterator(values, indices):
680680
Returns:
681681
PermutationIterator: Iterator that yields permuted values
682682
"""
683+
from ..struct import make_struct_type
684+
683685
# Convert arrays to iterators if needed
684686
if hasattr(values, "__cuda_array_interface__"):
685687
values = pointer(values, numba.from_dtype(get_dtype(values)))
@@ -715,22 +717,20 @@ def make_permutation_iterator(values, indices):
715717
# The cvalue and state for PermutationIterator are
716718
# structs composed of the cvalues and states of the
717719
# value and index iterators.
718-
from ..struct import gpu_struct_from_numba_types
719-
720720
class PermutationCValueStruct(ctypes.Structure):
721721
_fields_ = [
722722
("value_state", values.cvalue.__class__),
723723
("index_state", indices.cvalue.__class__),
724724
]
725725

726-
PermutationState = gpu_struct_from_numba_types(
726+
PermutationState = make_struct_type(
727727
"PermutationState",
728-
("value_state", "index_state"),
729-
(values_state_type, indices.state_type),
728+
field_names=("value_state", "index_state"),
729+
field_types=(values_state_type, indices.state_type),
730730
)
731731

732732
cvalue = PermutationCValueStruct(values.cvalue, indices.cvalue)
733-
state_type = PermutationState._numba_type
733+
state_type = as_numba_type(PermutationState)
734734
value_type = value_dtype
735735

736736
# Define intrinsics for accessing struct fields

python/cuda_cccl/cuda/compute/iterators/_zip_iterator.py

Lines changed: 32 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -4,14 +4,12 @@
44

55
import ctypes
66

7-
import numba
87
from llvmlite import ir # noqa: F401
98
from numba import cuda, types # noqa: F401
109
from numba.core.datamodel.registry import default_manager # noqa: F401
11-
from numba.core.extending import intrinsic # noqa: F401
10+
from numba.core.extending import as_numba_type, intrinsic # noqa: F401
1211

13-
from .._utils.protocols import get_dtype
14-
from ..struct import gpu_struct_from_numba_types
12+
from ..struct import make_struct_type
1513
from ._iterators import (
1614
IteratorBase,
1715
IteratorKind,
@@ -36,22 +34,18 @@ def _get_zip_iterator_metadata(iterators):
3634
# this iterator's state is a struct composed of the states of the input iterators:
3735
state_field_names = tuple(f"state_{i}" for i in range(n_iterators))
3836
state_field_types = tuple(it.state_type for it in iterators)
39-
ZipState = gpu_struct_from_numba_types(
40-
"ZipState", state_field_names, state_field_types
41-
)
37+
ZipState = make_struct_type("ZipState", state_field_names, state_field_types)
4238

4339
# this iterator's value is a struct composed of the values of the input iterators:
4440
value_field_names = tuple(f"value_{i}" for i in range(n_iterators))
4541
value_field_types = tuple(it.value_type for it in iterators)
46-
ZipValue = gpu_struct_from_numba_types(
47-
"ZipValue", value_field_names, value_field_types
48-
)
42+
ZipValue = make_struct_type("ZipValue", value_field_names, value_field_types)
4943

5044
cvalue = ZipCValueStruct(
5145
**{f"iter_{i}": it.cvalue for i, it in enumerate(iterators)}
5246
)
53-
state_type = ZipState._numba_type
54-
value_type = ZipValue._numba_type
47+
state_type = as_numba_type(ZipState)
48+
value_type = as_numba_type(ZipValue)
5549
return cvalue, state_type, value_type, ZipValue
5650

5751

@@ -64,6 +58,15 @@ def _get_advance_and_dereference_functions(iterators):
6458

6559
n_iterators = len(iterators)
6660

61+
# Create a local namespace for this zip iterator to avoid polluting globals
62+
# and prevent name collisions when nesting zip iterators
63+
local_ns = {
64+
"intrinsic": intrinsic,
65+
"types": types,
66+
"ir": ir,
67+
"default_manager": default_manager,
68+
}
69+
6770
# Within the advance and dereference methods of this iterator, we need a way
6871
# to get pointers to the fields of the state struct (advance and dereference),
6972
# and the value struct (dereference). This needs `n` custom intrinsics, one
@@ -80,30 +83,29 @@ def codegen(context, builder, sig, args):
8083
# Use GEP to get pointer to field at index {field_idx}
8184
field_ptr = builder.gep(
8285
struct_ptr,
83-
[ir.Constant(ir.IntType(32), 0), ir.Constant(ir.IntType(32), {field_idx})],
86+
[ir.Constant(ir.IntType(32), 0), ir.Constant(
87+
ir.IntType(32), {field_idx})],
8488
)
8589
return field_ptr
8690
8791
struct_model = default_manager.lookup(struct_ptr_type.dtype)
8892
field_type = struct_model._members[{field_idx}]
8993
return types.CPointer(field_type)(struct_ptr_type), codegen
9094
"""
91-
# Execute the code to create the intrinsic function in global namespace
92-
exec(intrinsic_code, globals())
95+
# Execute the code to create the intrinsic function in local namespace
96+
exec(intrinsic_code, local_ns)
9397

9498
# Now we can define the advance and dereference methods of this iterator,
9599
# which also need to be defined dynamically because they will use the
96100
# intrinsics defined above.
97101
for i, it in enumerate(iterators):
98-
globals()[f"advance_{i}"] = cuda.jit(it.advance, device=True)
99-
globals()[f"input_dereference_{i}"] = cuda.jit(
100-
it.input_dereference, device=True
101-
)
102+
local_ns[f"advance_{i}"] = cuda.jit(it.advance, device=True)
103+
local_ns[f"input_dereference_{i}"] = cuda.jit(it.input_dereference, device=True)
102104
# Also compile output_dereference if available
103105
try:
104106
output_deref = it.output_dereference
105107
if output_deref is not None:
106-
globals()[f"output_dereference_{i}"] = cuda.jit(
108+
local_ns[f"output_dereference_{i}"] = cuda.jit(
107109
output_deref, device=True
108110
)
109111
except AttributeError:
@@ -152,12 +154,12 @@ def input_dereference(state, result):
152154
{chr(10).join(input_dereference_lines)}
153155
""" # chr(10) is '\n'
154156

155-
# Execute the method codes:
156-
exec(advance_method_code, globals())
157-
exec(input_dereference_method_code, globals())
157+
# Execute the method codes in local namespace:
158+
exec(advance_method_code, local_ns)
159+
exec(input_dereference_method_code, local_ns)
158160

159-
advance_func = globals()["input_advance"]
160-
input_dereference_func = globals()["input_dereference"]
161+
advance_func = local_ns["input_advance"]
162+
input_dereference_func = local_ns["input_dereference"]
161163

162164
# Generate output_dereference if all iterators support it
163165
output_dereference_func = None
@@ -167,8 +169,8 @@ def output_dereference(state, x):
167169
# Write to each iterator using dynamically created field pointer functions
168170
{chr(10).join(output_dereference_lines)}
169171
"""
170-
exec(output_dereference_method_code, globals())
171-
output_dereference_func = globals()["output_dereference"]
172+
exec(output_dereference_method_code, local_ns)
173+
output_dereference_func = local_ns["output_dereference"]
172174

173175
return advance_func, input_dereference_func, output_dereference_func
174176

@@ -183,14 +185,16 @@ def make_zip_iterator(*iterators):
183185
Returns:
184186
ZipIterator: Iterator that combines all input iterators
185187
"""
188+
from .._cccl_interop import get_value_type
189+
186190
if len(iterators) < 1:
187191
raise ValueError("At least 1 iterator is required")
188192

189193
# Convert arrays to iterators if needed
190194
processed_iterators = []
191195
for it in iterators:
192196
if hasattr(it, "__cuda_array_interface__"):
193-
it = pointer(it, numba.from_dtype(get_dtype(it)))
197+
it = pointer(it, get_value_type(it))
194198
processed_iterators.append(it)
195199

196200
# Validate all are iterators

python/cuda_cccl/cuda/compute/numba_utils.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import numba
44
from numba import cuda
5+
from numba.core.extending import as_numba_type
56
from numpy.typing import DTypeLike
67

78
from .typing import GpuStruct
@@ -16,8 +17,8 @@ def to_numba_type(tp: GpuStruct | DTypeLike) -> numba.types.Type:
1617
"""
1718
Convert a GpuStruct or DTypeLike to a numba type.
1819
"""
19-
if hasattr(tp, "_numba_type"):
20-
return tp._numba_type # type: ignore[union-attr]
20+
if value := as_numba_type.lookup.get(tp):
21+
return value
2122
return numba.from_dtype(tp)
2223

2324

0 commit comments

Comments
 (0)