2
2
#
3
3
# SPDX-License-Identifier: Apache-2.0
4
4
5
+ from functools import partial
6
+
5
7
import dpnp
6
8
from numba .core import ir , types
7
9
from numba .core .ir_utils import get_np_ufunc_typ , mk_unique_var
8
- from numba .core .pythonapi import NativeValue , PythonAPI , box , unbox
9
10
10
11
from .usm_ndarray_type import USMNdArray
11
12
12
13
14
def partialclass(cls, *args, **kwds):
    """Create a factory subclass of *cls* with preset constructor arguments.

    The returned class subclasses *cls* (so ``issubclass`` checks still
    hold), but instantiating it delegates to
    ``functools.partial(cls, *args, **kwds)``.  The object produced is a
    plain *cls* instance built with the preset arguments merged into the
    call-site arguments; because ``__new__`` returns an instance whose type
    is *cls* rather than the subclass, Python does not invoke ``__init__``
    a second time.
    """
    preset_ctor = partial(cls, *args, **kwds)

    def _delegating_new(_cls, *call_args, **call_kwds):
        # Ignore the subclass handle and construct via the preset partial.
        return preset_ctor(*call_args, **call_kwds)

    return type(cls.__name__ + "Partial", (cls,), {"__new__": _delegating_new})
25
+
26
+
13
27
class DpnpNdArray (USMNdArray ):
14
28
"""
15
29
The Numba type to represent an dpnp.ndarray. The type has the same
@@ -40,15 +54,22 @@ def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
40
54
Returns: The DpnpNdArray class.
41
55
"""
42
56
if method == "__call__" :
43
- if not all (
44
- (
45
- isinstance (inp , DpnpNdArray )
46
- or isinstance (inp , types .abstract .Number )
47
- )
48
- for inp in inputs
49
- ):
57
+ dpnp_type = None
58
+
59
+ for inp in inputs :
60
+ if isinstance (inp , DpnpNdArray ):
61
+ dpnp_type = inp
62
+ continue
63
+ if isinstance (inp , types .abstract .Number ):
64
+ continue
65
+
50
66
return NotImplemented
51
- return DpnpNdArray
67
+
68
+ assert dpnp_type is not None
69
+
70
+ return partialclass (
71
+ DpnpNdArray , queue = dpnp_type .queue , usm_type = dpnp_type .usm_type
72
+ )
52
73
else :
53
74
return
54
75
@@ -71,6 +92,8 @@ def __allocate__(
71
92
lhs_typ ,
72
93
size_typ ,
73
94
out ,
95
+ # dpex specific argument:
96
+ queue_ir_val = None ,
74
97
):
75
98
"""Generates the Numba typed IR representing the allocation of a new
76
99
DpnpNdArray using the dpnp.ndarray overload.
@@ -94,6 +117,10 @@ def __allocate__(
94
117
95
118
Returns: The IR Value for the allocated array
96
119
"""
120
+ # TODO: it looks like it is being called only for parfor allocations,
121
+ # so can we rely on it? We can grab information from input arguments
122
+ # from rhs, but doc does not set any restriction on parfor use only.
123
+ assert queue_ir_val is not None
97
124
g_np_var = ir .Var (scope , mk_unique_var ("$np_g_var" ), loc )
98
125
if typemap :
99
126
typemap [g_np_var .name ] = types .misc .Module (dpnp )
@@ -132,11 +159,13 @@ def __allocate__(
132
159
usm_typ_var = ir .Var (scope , mk_unique_var ("$np_usm_type_var" ), loc )
133
160
# A default device string arg added as a placeholder
134
161
device_typ_var = ir .Var (scope , mk_unique_var ("$np_device_var" ), loc )
162
+ queue_typ_var = ir .Var (scope , mk_unique_var ("$np_queue_var" ), loc )
135
163
136
164
if typemap :
137
165
typemap [layout_var .name ] = types .literal (lhs_typ .layout )
138
166
typemap [usm_typ_var .name ] = types .literal (lhs_typ .usm_type )
139
- typemap [device_typ_var .name ] = types .literal (lhs_typ .device )
167
+ typemap [device_typ_var .name ] = types .none
168
+ typemap [queue_typ_var .name ] = lhs_typ .queue
140
169
141
170
layout_var_assign = ir .Assign (
142
171
ir .Const (lhs_typ .layout , loc ), layout_var , loc
@@ -145,16 +174,29 @@ def __allocate__(
145
174
ir .Const (lhs_typ .usm_type , loc ), usm_typ_var , loc
146
175
)
147
176
device_typ_var_assign = ir .Assign (
148
- ir .Const (lhs_typ . device , loc ), device_typ_var , loc
177
+ ir .Const (None , loc ), device_typ_var , loc
149
178
)
179
+ queue_typ_var_assign = ir .Assign (queue_ir_val , queue_typ_var , loc )
150
180
151
181
out .extend (
152
- [layout_var_assign , usm_typ_var_assign , device_typ_var_assign ]
182
+ [
183
+ layout_var_assign ,
184
+ usm_typ_var_assign ,
185
+ device_typ_var_assign ,
186
+ queue_typ_var_assign ,
187
+ ]
153
188
)
154
189
155
190
alloc_call = ir .Expr .call (
156
191
attr_var ,
157
- [size_var , typ_var , layout_var , device_typ_var , usm_typ_var ],
192
+ [
193
+ size_var ,
194
+ typ_var ,
195
+ layout_var ,
196
+ device_typ_var ,
197
+ usm_typ_var ,
198
+ queue_typ_var ,
199
+ ],
158
200
(),
159
201
loc ,
160
202
)
@@ -170,6 +212,7 @@ def __allocate__(
170
212
layout_var ,
171
213
device_typ_var ,
172
214
usm_typ_var ,
215
+ queue_typ_var ,
173
216
]
174
217
],
175
218
{},
0 commit comments