Skip to content

Commit a2d016a

Browse files
committed
Clean up private array generation
1 parent 9cb0cc9 commit a2d016a

File tree

2 files changed

+20
-36
lines changed

2 files changed

+20
-36
lines changed

numba_dpex/kernel_api_impl/spirv/arrayobj.py

Lines changed: 5 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,6 @@
44

55
"""Contains SPIR-V specific array functions."""
66

7-
import operator
8-
from functools import reduce
97
from typing import Union
108

119
import llvmlite.ir as llvmir
@@ -34,18 +32,21 @@ def require_literal(literal_type: types.Type):
3432
)
3533

3634

37-
def make_spirv_array( # pylint: disable=too-many-arguments
35+
def np_cfarray( # pylint: disable=too-many-arguments
3836
context: BaseContext,
3937
builder: IRBuilder,
4038
ty_array: types.Array,
4139
ty_shape: Union[types.IntegerLiteral, types.BaseTuple],
4240
shape: llvmir.Value,
4341
data: llvmir.Value,
4442
):
45-
"""Makes SPIR-V array and fills it data.
43+
"""Makes numpy-like array and fills it's data depending on the context's
44+
datamodel.
4645
4746
Generic version of numba.np.arrayobj.np_cfarray so that it can be used
4847
not only as intrinsic, but inside instruction generation.
48+
49+
TODO: upstream changes to numba.
4950
"""
5051
# Create array object
5152
ary = context.make_array(ty_array)(context, builder)
@@ -84,32 +85,3 @@ def make_spirv_array( # pylint: disable=too-many-arguments
8485
)
8586

8687
return ary
87-
88-
89-
def allocate_array_data_on_stack(
90-
context: BaseContext,
91-
builder: IRBuilder,
92-
ty_array: types.Array,
93-
ty_shape: Union[types.IntegerLiteral, types.BaseTuple],
94-
):
95-
"""Allocates flat array of given shape on the stack."""
96-
if not isinstance(ty_shape, types.BaseTuple):
97-
ty_shape = (ty_shape,)
98-
99-
return cgutils.alloca_once(
100-
builder,
101-
context.get_data_type(ty_array.dtype),
102-
size=reduce(operator.mul, [s.literal_value for s in ty_shape]),
103-
)
104-
105-
106-
def make_spirv_generic_array_on_stack(
107-
context: BaseContext,
108-
builder: IRBuilder,
109-
ty_array: types.Array,
110-
ty_shape: Union[types.IntegerLiteral, types.BaseTuple],
111-
shape: llvmir.Value,
112-
):
113-
"""Makes SPIR-V array of given shape with empty data."""
114-
data = allocate_array_data_on_stack(context, builder, ty_array, ty_shape)
115-
return make_spirv_array(context, builder, ty_array, ty_shape, shape, data)

numba_dpex/kernel_api_impl/spirv/overloads/_private_array_overloads.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,9 @@
77
"""
88

99

10+
import operator
11+
from functools import reduce
12+
1013
import llvmlite.ir as llvmir
1114
from llvmlite.ir.builder import IRBuilder
1215
from numba.core import cgutils, types
@@ -18,7 +21,7 @@
1821
from numba_dpex.core.types import USMNdArray
1922
from numba_dpex.kernel_api import PrivateArray
2023
from numba_dpex.kernel_api_impl.spirv.arrayobj import (
21-
make_spirv_generic_array_on_stack,
24+
np_cfarray,
2225
require_literal,
2326
)
2427
from numba_dpex.kernel_api_impl.spirv.target import SPIRVTypingContext
@@ -74,10 +77,19 @@ def dpex_private_array_lower(
7477
fill_zeros = False
7578
ty_array = sig.return_type
7679

77-
ary = make_spirv_generic_array_on_stack(
78-
context, builder, ty_array, ty_shape, shape
80+
# Allocate data on stack
81+
data = cgutils.alloca_once(
82+
builder,
83+
context.get_data_type(ty_array.dtype),
84+
size=(
85+
reduce(operator.mul, [s.literal_value for s in ty_shape])
86+
if isinstance(ty_shape, types.BaseTuple)
87+
else ty_shape.literal_value
88+
),
7989
)
8090

91+
ary = np_cfarray(context, builder, ty_array, ty_shape, shape, data)
92+
8193
if fill_zeros:
8294
cgutils.memset(
8395
builder, ary.data, builder.mul(ary.itemsize, ary.nitems), 0

0 commit comments

Comments
 (0)