Skip to content

Commit 7e44902

Browse files
author
Diptorup Deb
committed
Address review comments.
- A static variable Range.UNDEFINED_DIMENSION is now used instead of the magic -1 value to indicate that the dimension extent is undefined. - An ndim_obj need not be created when boxing a Range.
1 parent 2929350 commit 7e44902

File tree

3 files changed

+22
-20
lines changed

3 files changed

+22
-20
lines changed

numba_dpex/core/kernel_interface/indexers.py

Lines changed: 19 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,8 @@ class Range(tuple):
2323
the behavior of `sycl::range`.
2424
"""
2525

26+
UNDEFINED_DIMENSION = -1
27+
2628
def __new__(cls, dim0, dim1=None, dim2=None):
2729
"""Constructs a 1, 2, or 3 dimensional range.
2830
@@ -107,8 +109,8 @@ def dim1(self) -> int:
107109
"""
108110
try:
109111
return self[1]
110-
except:
111-
return -1
112+
except IndexError:
113+
return Range.UNDEFINED_DIMENSION
112114

113115
@property
114116
def dim2(self) -> int:
@@ -120,8 +122,8 @@ def dim2(self) -> int:
120122
"""
121123
try:
122124
return self[2]
123-
except:
124-
return -1
125+
except IndexError:
126+
return Range.UNDEFINED_DIMENSION
125127

126128

127129
class NdRange:
@@ -220,10 +222,13 @@ def __repr__(self):
220222
return self.__str__()
221223

222224
def __eq__(self, other):
223-
return (
224-
self.global_range == other.global_range
225-
and self.local_range == other.local_range
226-
)
225+
if isinstance(other, NdRange):
226+
return (
227+
self.global_range == other.global_range
228+
and self.local_range == other.local_range
229+
)
230+
else:
231+
return False
227232

228233

229234
@intrinsic
@@ -245,12 +250,16 @@ def codegen(context, builder, sig, args):
245250
if not isinstance(sig.args[1], types.NoneType):
246251
range_struct.dim1 = dim1
247252
else:
248-
range_struct.dim1 = llvmir.Constant(llvmir.types.IntType(64), -1)
253+
range_struct.dim1 = llvmir.Constant(
254+
llvmir.types.IntType(64), Range.UNDEFINED_DIMENSION
255+
)
249256

250257
if not isinstance(sig.args[2], types.NoneType):
251258
range_struct.dim2 = dim2
252259
else:
253-
range_struct.dim2 = llvmir.Constant(llvmir.types.IntType(64), -1)
260+
range_struct.dim2 = llvmir.Constant(
261+
llvmir.types.IntType(64), Range.UNDEFINED_DIMENSION
262+
)
254263

255264
range_struct.ndim = llvmir.Constant(llvmir.types.IntType(64), typ.ndim)
256265

numba_dpex/core/types/range_types.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -169,9 +169,6 @@ def box_range(typ, val, c):
169169
c.context, c.builder, value=val
170170
)
171171

172-
ndim_obj = c.box(types.int64, range_struct.ndim)
173-
with cgutils.early_exit_if_null(c.builder, stack, ndim_obj):
174-
c.builder.store(fail_obj, ret_ptr)
175172
dim0_obj = c.box(types.int64, range_struct.dim0)
176173
with cgutils.early_exit_if_null(c.builder, stack, dim0_obj):
177174
c.builder.store(fail_obj, ret_ptr)
@@ -184,7 +181,6 @@ def box_range(typ, val, c):
184181

185182
class_obj = c.pyapi.unserialize(c.pyapi.serialize_object(Range))
186183
with cgutils.early_exit_if_null(c.builder, stack, class_obj):
187-
c.pyapi.decref(ndim_obj)
188184
c.pyapi.decref(dim0_obj)
189185
c.pyapi.decref(dim1_obj)
190186
c.pyapi.decref(dim2_obj)
@@ -203,7 +199,6 @@ def box_range(typ, val, c):
203199
else:
204200
raise ValueError("Cannot unbox Range instance.")
205201

206-
c.pyapi.decref(ndim_obj)
207202
c.pyapi.decref(dim0_obj)
208203
c.pyapi.decref(dim1_obj)
209204
c.pyapi.decref(dim2_obj)

numba_dpex/core/typing/typeof.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,12 +10,8 @@
1010

1111
from numba_dpex.utils import address_space
1212

13-
<<<<<<< HEAD
14-
from ..types.dpctl_types import DpctlSyclEvent, DpctlSyclQueue
15-
=======
1613
from ..kernel_interface.indexers import NdRange, Range
17-
from ..types.dpctl_types import DpctlSyclQueue
18-
>>>>>>> 1d5800cf8 (Adds Range and NdRange as supported types in numba_dpex.dpjit.)
14+
from ..types.dpctl_types import DpctlSyclEvent, DpctlSyclQueue
1915
from ..types.dpnp_ndarray_type import DpnpNdArray
2016
from ..types.range_types import NdRangeType, RangeType
2117
from ..types.usm_ndarray_type import USMNdArray
@@ -126,6 +122,8 @@ def typeof_dpctl_sycl_event(val, c):
126122
Returns: A numba_dpex.core.types.dpctl_types.DpctlSyclEvent instance.
127123
"""
128124
return DpctlSyclEvent(val)
125+
126+
129127
@typeof_impl.register(Range)
130128
def typeof_range(val, c):
131129
"""Registers the type inference implementation function for a

0 commit comments

Comments
 (0)