1+ # SPDX-FileCopyrightText: 2012 - 2024 Anaconda Inc.
12# SPDX-FileCopyrightText: 2024 Intel Corporation
23#
34# SPDX-License-Identifier: Apache-2.0
5+ # SPDX-License-Identifier: BSD-2-Clause
46
57"""Contains SPIR-V specific array functions."""
68
7- import operator
8- from functools import reduce
99from typing import Union
1010
1111import llvmlite .ir as llvmir
1212from llvmlite .ir .builder import IRBuilder
1313from numba .core import cgutils , errors , types
1414from numba .core .base import BaseContext
15+ from numba .np .arrayobj import (
16+ basic_indexing ,
17+ get_itemsize ,
18+ load_item ,
19+ make_array ,
20+ )
1521
16- from numba_dpex .kernel_api_impl .spirv .target import SPIRVTargetContext
17- from numba_dpex .ocl .oclimpl import _get_target_data
22+ from numba_dpex .core .types import USMNdArray
1823
1924
20- def get_itemsize (context : SPIRVTargetContext , array_type : types .Array ):
25+ def populate_array (
26+ arraystruct , data , shape , strides , itemsize
27+ ): # pylint: disable=too-many-arguments,too-many-locals
2128 """
22- Return the item size for the given array or buffer type.
23- Same as numba.np.arrayobj.get_itemsize, but using spirv data.
29+ Helper function for populating array structures.
30+
31+ The function is copied from upstream Numba and modified to support the
32+ USMNdArray data type that uses a different data model on SYCL devices
33+ than the upstream types.Array data type. USMNdArray data model does not
34+ have the ``parent`` and ``meminfo`` fields. This function intended to be
35+ used only in the SPIRVKernelTarget.
36+
37+ *shape* and *strides* can be Python tuples or LLVM arrays.
38+ """
39+ context = arraystruct ._context # pylint: disable=protected-access
40+ builder = arraystruct ._builder # pylint: disable=protected-access
41+ datamodel = arraystruct ._datamodel # pylint: disable=protected-access
42+ # doesn't matter what this array type instance is, it's just to get the
43+ # fields for the data model of the standard array type in this context
44+ standard_array = USMNdArray (ndim = 1 , layout = "C" , dtype = types .float64 )
45+ standard_array_type_datamodel = context .data_model_manager [standard_array ]
46+ required_fields = set (standard_array_type_datamodel ._fields )
47+ datamodel_fields = set (datamodel ._fields )
48+ # Make sure that the presented array object has a data model that is
49+ # close enough to an array for this function to proceed.
50+ if (required_fields & datamodel_fields ) != required_fields :
51+ missing = required_fields - datamodel_fields
52+ msg = (
53+ f"The datamodel for type { arraystruct } is missing "
54+ f"field{ 's' if len (missing ) > 1 else '' } { missing } ."
55+ )
56+ raise ValueError (msg )
57+
58+ intp_t = context .get_value_type (types .intp )
59+ if isinstance (shape , (tuple , list )):
60+ shape = cgutils .pack_array (builder , shape , intp_t )
61+ if isinstance (strides , (tuple , list )):
62+ strides = cgutils .pack_array (builder , strides , intp_t )
63+ if isinstance (itemsize , int ):
64+ itemsize = intp_t (itemsize )
65+
66+ attrs = {
67+ "shape" : shape ,
68+ "strides" : strides ,
69+ "data" : data ,
70+ "itemsize" : itemsize ,
71+ }
72+
73+ # Calc num of items from shape
74+ nitems = context .get_constant (types .intp , 1 )
75+ unpacked_shape = cgutils .unpack_tuple (builder , shape , shape .type .count )
76+ # (note empty shape => 0d array therefore nitems = 1)
77+ for axlen in unpacked_shape :
78+ nitems = builder .mul (nitems , axlen , flags = ["nsw" ])
79+ attrs ["nitems" ] = nitems
80+
81+ # Make sure that we have all the fields
82+ got_fields = set (attrs .keys ())
83+ if got_fields != required_fields :
84+ raise ValueError (f"missing { required_fields - got_fields } " )
85+
86+ # Set field value
87+ for k , v in attrs .items ():
88+ setattr (arraystruct , k , v )
89+
90+ return arraystruct
91+
92+
93+ def make_view (
94+ context , builder , ary , return_type , data , shapes , strides
95+ ): # pylint: disable=too-many-arguments
96+ """
97+ Build a view over the given array with the given parameters.
98+
99+ This is analog of numpy.np.arrayobj.make_view without parent and
100+ meminfo fields, because they don't make sense on device. This function
101+ intended to be used only in kernel targets.
102+ """
103+ retary = make_array (return_type )(context , builder )
104+ context .populate_array (
105+ retary , data = data , shape = shapes , strides = strides , itemsize = ary .itemsize
106+ )
107+ return retary
108+
109+
110+ def _getitem_array_generic (
111+ context , builder , return_type , aryty , ary , index_types , indices
112+ ): # pylint: disable=too-many-arguments
24113 """
25- targetdata = _get_target_data (context )
26- lldtype = context .get_data_type (array_type .dtype )
27- return lldtype .get_abi_size (targetdata )
114+ Return the result of indexing *ary* with the given *indices*,
115+ returning either a scalar or a view.
116+
117+ This is analog of numpy.np.arrayobj._getitem_array_generic without parent
118+ and meminfo fields, because they don't make sense on device. This function
119+ intended to be used only in kernel targets.
120+ """
121+ dataptr , view_shapes , view_strides = basic_indexing (
122+ context ,
123+ builder ,
124+ aryty ,
125+ ary ,
126+ index_types ,
127+ indices ,
128+ boundscheck = context .enable_boundscheck ,
129+ )
130+
131+ if isinstance (return_type , types .Buffer ):
132+ # Build array view
133+ retary = make_view (
134+ context ,
135+ builder ,
136+ ary ,
137+ return_type ,
138+ dataptr ,
139+ view_shapes ,
140+ view_strides ,
141+ )
142+ return retary ._getvalue () # pylint: disable=protected-access
143+
144+ # Load scalar from 0-d result
145+ assert not view_shapes
146+ return load_item (context , builder , aryty , dataptr )
28147
29148
30149def require_literal (literal_type : types .Type ):
@@ -46,15 +165,22 @@ def require_literal(literal_type: types.Type):
46165 )
47166
48167
49- def make_spirv_array ( # pylint: disable=too-many-arguments
50- context : SPIRVTargetContext ,
168+ def np_cfarray ( # pylint: disable=too-many-arguments
169+ context : BaseContext ,
51170 builder : IRBuilder ,
52171 ty_array : types .Array ,
53172 ty_shape : Union [types .IntegerLiteral , types .BaseTuple ],
54173 shape : llvmir .Value ,
55174 data : llvmir .Value ,
56175):
57- """Makes SPIR-V array and fills it data."""
176+ """Makes numpy-like array and fills it's data depending on the context's
177+ datamodel.
178+
179+ Generic version of numba.np.arrayobj.np_cfarray so that it can be used
180+ not only as intrinsic, but inside instruction generation.
181+
182+ TODO: upstream changes to numba.
183+ """
58184 # Create array object
59185 ary = context .make_array (ty_array )(context , builder )
60186
@@ -92,32 +218,3 @@ def make_spirv_array( # pylint: disable=too-many-arguments
92218 )
93219
94220 return ary
95-
96-
97- def allocate_array_data_on_stack (
98- context : BaseContext ,
99- builder : IRBuilder ,
100- ty_array : types .Array ,
101- ty_shape : Union [types .IntegerLiteral , types .BaseTuple ],
102- ):
103- """Allocates flat array of given shape on the stack."""
104- if not isinstance (ty_shape , types .BaseTuple ):
105- ty_shape = (ty_shape ,)
106-
107- return cgutils .alloca_once (
108- builder ,
109- context .get_data_type (ty_array .dtype ),
110- size = reduce (operator .mul , [s .literal_value for s in ty_shape ]),
111- )
112-
113-
114- def make_spirv_generic_array_on_stack (
115- context : SPIRVTargetContext ,
116- builder : IRBuilder ,
117- ty_array : types .Array ,
118- ty_shape : Union [types .IntegerLiteral , types .BaseTuple ],
119- shape : llvmir .Value ,
120- ):
121- """Makes SPIR-V array of given shape with empty data."""
122- data = allocate_array_data_on_stack (context , builder , ty_array , ty_shape )
123- return make_spirv_array (context , builder , ty_array , ty_shape , shape , data )
0 commit comments