44
55import ctypes
66
7- import numba
87from llvmlite import ir # noqa: F401
98from numba import cuda , types # noqa: F401
109from numba .core .datamodel .registry import default_manager # noqa: F401
11- from numba .core .extending import intrinsic # noqa: F401
10+ from numba .core .extending import as_numba_type , intrinsic # noqa: F401
1211
13- from .._utils .protocols import get_dtype
14- from ..struct import gpu_struct_from_numba_types
12+ from ..struct import make_struct_type
1513from ._iterators import (
1614 IteratorBase ,
1715 IteratorKind ,
@@ -36,22 +34,18 @@ def _get_zip_iterator_metadata(iterators):
3634 # this iterator's state is a struct composed of the states of the input iterators:
3735 state_field_names = tuple (f"state_{ i } " for i in range (n_iterators ))
3836 state_field_types = tuple (it .state_type for it in iterators )
39- ZipState = gpu_struct_from_numba_types (
40- "ZipState" , state_field_names , state_field_types
41- )
37+ ZipState = make_struct_type ("ZipState" , state_field_names , state_field_types )
4238
4339 # this iterator's value is a struct composed of the values of the input iterators:
4440 value_field_names = tuple (f"value_{ i } " for i in range (n_iterators ))
4541 value_field_types = tuple (it .value_type for it in iterators )
46- ZipValue = gpu_struct_from_numba_types (
47- "ZipValue" , value_field_names , value_field_types
48- )
42+ ZipValue = make_struct_type ("ZipValue" , value_field_names , value_field_types )
4943
5044 cvalue = ZipCValueStruct (
5145 ** {f"iter_{ i } " : it .cvalue for i , it in enumerate (iterators )}
5246 )
53- state_type = ZipState . _numba_type
54- value_type = ZipValue . _numba_type
47+ state_type = as_numba_type ( ZipState )
48+ value_type = as_numba_type ( ZipValue )
5549 return cvalue , state_type , value_type , ZipValue
5650
5751
@@ -64,6 +58,15 @@ def _get_advance_and_dereference_functions(iterators):
6458
6559 n_iterators = len (iterators )
6660
61+ # Create a local namespace for this zip iterator to avoid polluting globals
62+ # and prevent name collisions when nesting zip iterators
63+ local_ns = {
64+ "intrinsic" : intrinsic ,
65+ "types" : types ,
66+ "ir" : ir ,
67+ "default_manager" : default_manager ,
68+ }
69+
6770 # Within the advance and dereference methods of this iterator, we need a way
6871 # to get pointers to the fields of the state struct (advance and dereference),
6972 # and the value struct (dereference). This needs `n` custom intrinsics, one
@@ -80,30 +83,29 @@ def codegen(context, builder, sig, args):
8083 # Use GEP to get pointer to field at index { field_idx }
8184 field_ptr = builder.gep(
8285 struct_ptr,
83- [ir.Constant(ir.IntType(32), 0), ir.Constant(ir.IntType(32), { field_idx } )],
86+ [ir.Constant(ir.IntType(32), 0), ir.Constant(
87+ ir.IntType(32), { field_idx } )],
8488 )
8589 return field_ptr
8690
8791 struct_model = default_manager.lookup(struct_ptr_type.dtype)
8892 field_type = struct_model._members[{ field_idx } ]
8993 return types.CPointer(field_type)(struct_ptr_type), codegen
9094"""
91- # Execute the code to create the intrinsic function in global namespace
92- exec (intrinsic_code , globals () )
95+ # Execute the code to create the intrinsic function in local namespace
96+ exec (intrinsic_code , local_ns )
9397
9498 # Now we can define the advance and dereference methods of this iterator,
9599 # which also need to be defined dynamically because they will use the
96100 # intrinsics defined above.
97101 for i , it in enumerate (iterators ):
98- globals ()[f"advance_{ i } " ] = cuda .jit (it .advance , device = True )
99- globals ()[f"input_dereference_{ i } " ] = cuda .jit (
100- it .input_dereference , device = True
101- )
102+ local_ns [f"advance_{ i } " ] = cuda .jit (it .advance , device = True )
103+ local_ns [f"input_dereference_{ i } " ] = cuda .jit (it .input_dereference , device = True )
102104 # Also compile output_dereference if available
103105 try :
104106 output_deref = it .output_dereference
105107 if output_deref is not None :
106- globals () [f"output_dereference_{ i } " ] = cuda .jit (
108+ local_ns [f"output_dereference_{ i } " ] = cuda .jit (
107109 output_deref , device = True
108110 )
109111 except AttributeError :
@@ -152,12 +154,12 @@ def input_dereference(state, result):
152154{ chr (10 ).join (input_dereference_lines )}
153155""" # chr(10) is '\n'
154156
155- # Execute the method codes:
156- exec (advance_method_code , globals () )
157- exec (input_dereference_method_code , globals () )
157+ # Execute the method codes in local namespace :
158+ exec (advance_method_code , local_ns )
159+ exec (input_dereference_method_code , local_ns )
158160
159- advance_func = globals () ["input_advance" ]
160- input_dereference_func = globals () ["input_dereference" ]
161+ advance_func = local_ns ["input_advance" ]
162+ input_dereference_func = local_ns ["input_dereference" ]
161163
162164 # Generate output_dereference if all iterators support it
163165 output_dereference_func = None
@@ -167,8 +169,8 @@ def output_dereference(state, x):
167169 # Write to each iterator using dynamically created field pointer functions
168170{ chr (10 ).join (output_dereference_lines )}
169171"""
170- exec (output_dereference_method_code , globals () )
171- output_dereference_func = globals () ["output_dereference" ]
172+ exec (output_dereference_method_code , local_ns )
173+ output_dereference_func = local_ns ["output_dereference" ]
172174
173175 return advance_func , input_dereference_func , output_dereference_func
174176
@@ -183,14 +185,16 @@ def make_zip_iterator(*iterators):
183185 Returns:
184186 ZipIterator: Iterator that combines all input iterators
185187 """
188+ from .._cccl_interop import get_value_type
189+
186190 if len (iterators ) < 1 :
187191 raise ValueError ("At least 1 iterator is required" )
188192
189193 # Convert arrays to iterators if needed
190194 processed_iterators = []
191195 for it in iterators :
192196 if hasattr (it , "__cuda_array_interface__" ):
193- it = pointer (it , numba . from_dtype ( get_dtype ( it ) ))
197+ it = pointer (it , get_value_type ( it ))
194198 processed_iterators .append (it )
195199
196200 # Validate all are iterators
0 commit comments