Skip to content

Commit 9cb0cc9

Browse files
committed
Clean up target_data & item size usage
1 parent 0a21ec2 commit 9cb0cc9

File tree

3 files changed

+12
-25
lines changed

3 files changed

+12
-25
lines changed

numba_dpex/kernel_api_impl/spirv/arrayobj.py

Lines changed: 8 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -12,19 +12,7 @@
1212
from llvmlite.ir.builder import IRBuilder
1313
from numba.core import cgutils, errors, types
1414
from numba.core.base import BaseContext
15-
16-
from numba_dpex.kernel_api_impl.spirv.target import SPIRVTargetContext
17-
from numba_dpex.ocl.oclimpl import _get_target_data
18-
19-
20-
def get_itemsize(context: SPIRVTargetContext, array_type: types.Array):
21-
"""
22-
Return the item size for the given array or buffer type.
23-
Same as numba.np.arrayobj.get_itemsize, but using spirv data.
24-
"""
25-
targetdata = _get_target_data(context)
26-
lldtype = context.get_data_type(array_type.dtype)
27-
return lldtype.get_abi_size(targetdata)
15+
from numba.np.arrayobj import get_itemsize
2816

2917

3018
def require_literal(literal_type: types.Type):
@@ -47,14 +35,18 @@ def require_literal(literal_type: types.Type):
4735

4836

4937
def make_spirv_array( # pylint: disable=too-many-arguments
50-
context: SPIRVTargetContext,
38+
context: BaseContext,
5139
builder: IRBuilder,
5240
ty_array: types.Array,
5341
ty_shape: Union[types.IntegerLiteral, types.BaseTuple],
5442
shape: llvmir.Value,
5543
data: llvmir.Value,
5644
):
57-
"""Makes SPIR-V array and fills it data."""
45+
"""Makes SPIR-V array and fills it data.
46+
47+
Generic version of numba.np.arrayobj.np_cfarray so that it can be used
48+
not only as intrinsic, but inside instruction generation.
49+
"""
5850
# Create array object
5951
ary = context.make_array(ty_array)(context, builder)
6052

@@ -112,7 +104,7 @@ def allocate_array_data_on_stack(
112104

113105

114106
def make_spirv_generic_array_on_stack(
115-
context: SPIRVTargetContext,
107+
context: BaseContext,
116108
builder: IRBuilder,
117109
ty_array: types.Array,
118110
ty_shape: Union[types.IntegerLiteral, types.BaseTuple],

numba_dpex/kernel_api_impl/spirv/target.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
from llvmlite import ir as llvmir
1414
from numba.core import cgutils, funcdesc
1515
from numba.core import types as nb_types
16-
from numba.core import typing, utils
16+
from numba.core import typing
1717
from numba.core.base import BaseContext
1818
from numba.core.callconv import MinimalCallConv
1919
from numba.core.target_extension import GPU, target_registry
@@ -150,7 +150,7 @@ def init(self):
150150

151151
self._internal_codegen = codegen.JITSPIRVCodegen("numba_dpex.kernel")
152152
self._target_data = ll.create_target_data(
153-
codegen.SPIR_DATA_LAYOUT[utils.MACHINE_BITS]
153+
codegen.SPIR_DATA_LAYOUT[self.address_size]
154154
)
155155

156156
# Override data model manager to SPIR model

numba_dpex/ocl/oclimpl.py

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from numba.core import cgutils, types
1212
from numba.core.imputils import Registry
1313
from numba.core.typing.npydecl import parse_dtype
14+
from numba.np.arrayobj import get_itemsize
1415

1516
from numba_dpex import spirv_kernel_target
1617
from numba_dpex.core import config
@@ -403,9 +404,7 @@ def _make_array(
403404
aryty = Array(dtype=dtype, ndim=ndim, layout="C", addrspace=addrspace)
404405
ary = context.make_array(aryty)(context, builder)
405406

406-
targetdata = _get_target_data(context)
407-
lldtype = context.get_data_type(dtype)
408-
itemsize = lldtype.get_abi_size(targetdata)
407+
itemsize = get_itemsize(context, aryty)
409408
# Compute strides
410409
rstrides = [itemsize]
411410
for i, lastsize in enumerate(reversed(shape[1:])):
@@ -424,7 +423,3 @@ def _make_array(
424423
)
425424

426425
return ary._getvalue()
427-
428-
429-
def _get_target_data(context):
430-
return ll.create_target_data(SPIR_DATA_LAYOUT[context.address_size])

0 commit comments

Comments
 (0)