1
+ # SPDX-FileCopyrightText: 2012 - 2024 Anaconda Inc.
1
2
# SPDX-FileCopyrightText: 2024 Intel Corporation
2
3
#
3
4
# SPDX-License-Identifier: Apache-2.0
5
+ # SPDX-License-Identifier: BSD-2-Clause
4
6
5
7
"""Contains SPIR-V specific array functions."""
6
8
7
- import operator
8
- from functools import reduce
9
9
from typing import Union
10
10
11
11
import llvmlite .ir as llvmir
12
12
from llvmlite .ir .builder import IRBuilder
13
13
from numba .core import cgutils , errors , types
14
14
from numba .core .base import BaseContext
15
+ from numba .np .arrayobj import (
16
+ basic_indexing ,
17
+ get_itemsize ,
18
+ load_item ,
19
+ make_array ,
20
+ )
15
21
16
- from numba_dpex .kernel_api_impl .spirv .target import SPIRVTargetContext
17
- from numba_dpex .ocl .oclimpl import _get_target_data
22
+ from numba_dpex .core .types import USMNdArray
18
23
19
24
20
- def get_itemsize (context : SPIRVTargetContext , array_type : types .Array ):
25
+ def populate_array (
26
+ arraystruct , data , shape , strides , itemsize
27
+ ): # pylint: disable=too-many-arguments,too-many-locals
21
28
"""
22
- Return the item size for the given array or buffer type.
23
- Same as numba.np.arrayobj.get_itemsize, but using spirv data.
29
+ Helper function for populating array structures.
30
+
31
+ The function is copied from upstream Numba and modified to support the
32
+ USMNdArray data type that uses a different data model on SYCL devices
33
+ than the upstream types.Array data type. USMNdArray data model does not
34
+ have the ``parent`` and ``meminfo`` fields. This function intended to be
35
+ used only in the SPIRVKernelTarget.
36
+
37
+ *shape* and *strides* can be Python tuples or LLVM arrays.
38
+ """
39
+ context = arraystruct ._context # pylint: disable=protected-access
40
+ builder = arraystruct ._builder # pylint: disable=protected-access
41
+ datamodel = arraystruct ._datamodel # pylint: disable=protected-access
42
+ # doesn't matter what this array type instance is, it's just to get the
43
+ # fields for the data model of the standard array type in this context
44
+ standard_array = USMNdArray (ndim = 1 , layout = "C" , dtype = types .float64 )
45
+ standard_array_type_datamodel = context .data_model_manager [standard_array ]
46
+ required_fields = set (standard_array_type_datamodel ._fields )
47
+ datamodel_fields = set (datamodel ._fields )
48
+ # Make sure that the presented array object has a data model that is
49
+ # close enough to an array for this function to proceed.
50
+ if (required_fields & datamodel_fields ) != required_fields :
51
+ missing = required_fields - datamodel_fields
52
+ msg = (
53
+ f"The datamodel for type { arraystruct } is missing "
54
+ f"field{ 's' if len (missing ) > 1 else '' } { missing } ."
55
+ )
56
+ raise ValueError (msg )
57
+
58
+ intp_t = context .get_value_type (types .intp )
59
+ if isinstance (shape , (tuple , list )):
60
+ shape = cgutils .pack_array (builder , shape , intp_t )
61
+ if isinstance (strides , (tuple , list )):
62
+ strides = cgutils .pack_array (builder , strides , intp_t )
63
+ if isinstance (itemsize , int ):
64
+ itemsize = intp_t (itemsize )
65
+
66
+ attrs = {
67
+ "shape" : shape ,
68
+ "strides" : strides ,
69
+ "data" : data ,
70
+ "itemsize" : itemsize ,
71
+ }
72
+
73
+ # Calc num of items from shape
74
+ nitems = context .get_constant (types .intp , 1 )
75
+ unpacked_shape = cgutils .unpack_tuple (builder , shape , shape .type .count )
76
+ # (note empty shape => 0d array therefore nitems = 1)
77
+ for axlen in unpacked_shape :
78
+ nitems = builder .mul (nitems , axlen , flags = ["nsw" ])
79
+ attrs ["nitems" ] = nitems
80
+
81
+ # Make sure that we have all the fields
82
+ got_fields = set (attrs .keys ())
83
+ if got_fields != required_fields :
84
+ raise ValueError (f"missing { required_fields - got_fields } " )
85
+
86
+ # Set field value
87
+ for k , v in attrs .items ():
88
+ setattr (arraystruct , k , v )
89
+
90
+ return arraystruct
91
+
92
+
93
+ def make_view (
94
+ context , builder , ary , return_type , data , shapes , strides
95
+ ): # pylint: disable=too-many-arguments
96
+ """
97
+ Build a view over the given array with the given parameters.
98
+
99
+ This is analog of numpy.np.arrayobj.make_view without parent and
100
+ meminfo fields, because they don't make sense on device. This function
101
+ intended to be used only in kernel targets.
102
+ """
103
+ retary = make_array (return_type )(context , builder )
104
+ context .populate_array (
105
+ retary , data = data , shape = shapes , strides = strides , itemsize = ary .itemsize
106
+ )
107
+ return retary
108
+
109
+
110
+ def _getitem_array_generic (
111
+ context , builder , return_type , aryty , ary , index_types , indices
112
+ ): # pylint: disable=too-many-arguments
24
113
"""
25
- targetdata = _get_target_data (context )
26
- lldtype = context .get_data_type (array_type .dtype )
27
- return lldtype .get_abi_size (targetdata )
114
+ Return the result of indexing *ary* with the given *indices*,
115
+ returning either a scalar or a view.
116
+
117
+ This is analog of numpy.np.arrayobj._getitem_array_generic without parent
118
+ and meminfo fields, because they don't make sense on device. This function
119
+ intended to be used only in kernel targets.
120
+ """
121
+ dataptr , view_shapes , view_strides = basic_indexing (
122
+ context ,
123
+ builder ,
124
+ aryty ,
125
+ ary ,
126
+ index_types ,
127
+ indices ,
128
+ boundscheck = context .enable_boundscheck ,
129
+ )
130
+
131
+ if isinstance (return_type , types .Buffer ):
132
+ # Build array view
133
+ retary = make_view (
134
+ context ,
135
+ builder ,
136
+ ary ,
137
+ return_type ,
138
+ dataptr ,
139
+ view_shapes ,
140
+ view_strides ,
141
+ )
142
+ return retary ._getvalue () # pylint: disable=protected-access
143
+
144
+ # Load scalar from 0-d result
145
+ assert not view_shapes
146
+ return load_item (context , builder , aryty , dataptr )
28
147
29
148
30
149
def require_literal (literal_type : types .Type ):
@@ -46,15 +165,22 @@ def require_literal(literal_type: types.Type):
46
165
)
47
166
48
167
49
- def make_spirv_array ( # pylint: disable=too-many-arguments
50
- context : SPIRVTargetContext ,
168
+ def np_cfarray ( # pylint: disable=too-many-arguments
169
+ context : BaseContext ,
51
170
builder : IRBuilder ,
52
171
ty_array : types .Array ,
53
172
ty_shape : Union [types .IntegerLiteral , types .BaseTuple ],
54
173
shape : llvmir .Value ,
55
174
data : llvmir .Value ,
56
175
):
57
- """Makes SPIR-V array and fills it data."""
176
+ """Makes numpy-like array and fills it's data depending on the context's
177
+ datamodel.
178
+
179
+ Generic version of numba.np.arrayobj.np_cfarray so that it can be used
180
+ not only as intrinsic, but inside instruction generation.
181
+
182
+ TODO: upstream changes to numba.
183
+ """
58
184
# Create array object
59
185
ary = context .make_array (ty_array )(context , builder )
60
186
@@ -92,32 +218,3 @@ def make_spirv_array( # pylint: disable=too-many-arguments
92
218
)
93
219
94
220
return ary
95
-
96
-
97
- def allocate_array_data_on_stack (
98
- context : BaseContext ,
99
- builder : IRBuilder ,
100
- ty_array : types .Array ,
101
- ty_shape : Union [types .IntegerLiteral , types .BaseTuple ],
102
- ):
103
- """Allocates flat array of given shape on the stack."""
104
- if not isinstance (ty_shape , types .BaseTuple ):
105
- ty_shape = (ty_shape ,)
106
-
107
- return cgutils .alloca_once (
108
- builder ,
109
- context .get_data_type (ty_array .dtype ),
110
- size = reduce (operator .mul , [s .literal_value for s in ty_shape ]),
111
- )
112
-
113
-
114
- def make_spirv_generic_array_on_stack (
115
- context : SPIRVTargetContext ,
116
- builder : IRBuilder ,
117
- ty_array : types .Array ,
118
- ty_shape : Union [types .IntegerLiteral , types .BaseTuple ],
119
- shape : llvmir .Value ,
120
- ):
121
- """Makes SPIR-V array of given shape with empty data."""
122
- data = allocate_array_data_on_stack (context , builder , ty_array , ty_shape )
123
- return make_spirv_array (context , builder , ty_array , ty_shape , shape , data )
0 commit comments