
Commit 06af70b

Use argument queue_ref for empty array allocation for parfor
1 parent a60d032 commit 06af70b
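
For orientation (not part of the commit): the behavior this change targets is compute-follows-data allocation for parfors. A hedged sketch, assuming numba_dpex's dpjit decorator and dpnp arrays:

import dpnp
import numba_dpex

@numba_dpex.dpjit
def add(a, b):
    # Array expression -> parfor -> empty-array allocation for the result.
    return a + b

a = dpnp.ones(1024)
b = dpnp.ones(1024)
c = add(a, b)
# With this commit, the allocation for ``c`` reuses the sycl_queue taken
# from the input arrays (see _get_queue below) instead of a default
# device string.
assert c.sycl_queue == a.sycl_queue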

File tree

2 files changed: +157 -14 lines changed

numba_dpex/core/parfors/parfor_pass.py

Lines changed: 101 additions & 1 deletion
@@ -17,8 +17,8 @@
 from numba.core import config, errors, ir, types, typing
 from numba.core.compiler_machinery import register_pass
 from numba.core.ir_utils import (
+    convert_size_to_var,
     dprint_func_ir,
-    mk_alloc,
     mk_unique_var,
     next_label,
 )
@@ -43,6 +43,7 @@
 )
 from numba.stencils.stencilparfor import StencilPass
 
+from numba_dpex.core.types.dpnp_ndarray_type import DpnpNdArray
 from numba_dpex.core.typing import dpnpdecl
 
 
@@ -58,6 +59,37 @@ class ConvertDPNPPass(ConvertNumpyPass):
     def __init__(self, pass_states):
         super().__init__(pass_states)
 
+    def _get_queue(self, queue_type, expr: tuple):
+        """
+        Extracts queue from the input arguments of the array operation.
+        """
+        pass_states = self.pass_states
+        typemap: map[str, any] = pass_states.typemap
+
+        var_with_queue = None
+
+        for var in expr[1]:
+            if isinstance(var, tuple):
+                res = self._get_queue(queue_type, var)
+                if res is not None:
+                    return res
+
+                continue
+
+            if not isinstance(var, ir.Var):
+                continue
+
+            _type = typemap[var.name]
+            if not isinstance(_type, DpnpNdArray):
+                continue
+            if queue_type != _type.queue:
+                continue
+
+            var_with_queue = var
+            break
+
+        return ir.Expr.getattr(var_with_queue, "sycl_queue", var_with_queue.loc)
+
     def _arrayexpr_to_parfor(self, equiv_set, lhs, arrayexpr, avail_vars):
         """generate parfor from arrayexpr node, which is essentially a
         map with recursive tree.
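
The expr argument handed to _get_queue is a nested (op, operand-list) tuple. A minimal standalone sketch of the traversal order, with plain strings standing in for ir.Var and a set-membership test standing in for the typemap/queue checks (all names here are illustrative):

# Toy model of the arrayexpr tuple: ("op", operands), where an operand is
# either a leaf variable or another such tuple.
expr = ("add", ["a", ("mul", ["b", "c"]), 2.0])

def first_matching(expr, wanted):
    """Depth-first scan over expr[1], mirroring _get_queue: recurse into
    sub-expressions, skip non-variable leaves, return the first match."""
    for operand in expr[1]:
        if isinstance(operand, tuple):
            res = first_matching(operand, wanted)
            if res is not None:
                return res
            continue
        if not isinstance(operand, str):  # stands in for the ir.Var check
            continue
        if operand in wanted:  # stands in for the queue-equality check
            return operand
    return None

assert first_matching(expr, {"b", "c"}) == "b"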
@@ -77,6 +109,10 @@ def _arrayexpr_to_parfor(self, equiv_set, lhs, arrayexpr, avail_vars):
             pass_states.typemap, size_vars, scope, loc
         )
 
+        # Expr is a tuple
+        ir_queue = self._get_queue(arr_typ.queue, expr)
+        assert ir_queue is not None
+
         # generate init block and body
         init_block = ir.Block(scope, loc)
         init_block.body = mk_alloc(
@@ -89,6 +125,7 @@ def _arrayexpr_to_parfor(self, equiv_set, lhs, arrayexpr, avail_vars):
             scope,
             loc,
             pass_states.typemap[lhs.name],
+            queue_ir_val=ir_queue,
         )
         body_label = next_label()
         body_block = ir.Block(scope, loc)
@@ -469,3 +506,66 @@ def _arrayexpr_tree_to_ir(
         typemap.pop(expr_out_var.name, None)
     typemap[expr_out_var.name] = el_typ
     return out_ir
+
+
+def mk_alloc(
+    typingctx,
+    typemap,
+    calltypes,
+    lhs,
+    size_var,
+    dtype,
+    scope,
+    loc,
+    lhs_typ,
+    **kws,
+):
+    """generate an array allocation with np.empty() and return list of nodes.
+    size_var can be an int variable or tuple of int variables.
+    lhs_typ is the type of the array being allocated.
+
+    Taken from numba; a kws argument was added to pass through to __allocate__.
+    """
+    out = []
+    ndims = 1
+    size_typ = types.intp
+    if isinstance(size_var, tuple):
+        if len(size_var) == 1:
+            size_var = size_var[0]
+            size_var = convert_size_to_var(size_var, typemap, scope, loc, out)
+        else:
+            # tuple_var = build_tuple([size_var...])
+            ndims = len(size_var)
+            tuple_var = ir.Var(scope, mk_unique_var("$tuple_var"), loc)
+            if typemap:
+                typemap[tuple_var.name] = types.containers.UniTuple(
+                    types.intp, ndims
+                )
+            # constant sizes need to be assigned to vars
+            new_sizes = [
+                convert_size_to_var(s, typemap, scope, loc, out)
+                for s in size_var
+            ]
+            tuple_call = ir.Expr.build_tuple(new_sizes, loc)
+            tuple_assign = ir.Assign(tuple_call, tuple_var, loc)
+            out.append(tuple_assign)
+            size_var = tuple_var
+            size_typ = types.containers.UniTuple(types.intp, ndims)
+    if hasattr(lhs_typ, "__allocate__"):
+        return lhs_typ.__allocate__(
+            typingctx,
+            typemap,
+            calltypes,
+            lhs,
+            size_var,
+            dtype,
+            scope,
+            loc,
+            lhs_typ,
+            size_typ,
+            out,
+            **kws,
+        )
+
+    # The rest of numba's code is unused: dpex types always define __allocate__.
+    assert False
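
The only functional changes relative to numba's mk_alloc are the **kws pass-through and the hard assert at the end. A toy sketch of the hasattr-based dispatch contract (illustrative types, not numba's):

class QueueAwareArrayType:  # stands in for DpnpNdArray
    def __allocate__(self, size, **kws):
        return f"alloc({size}, queue={kws.get('queue_ir_val')})"

def mk_alloc_toy(lhs_typ, size, **kws):
    # Any type exposing __allocate__ takes over allocation and receives
    # the extra keyword arguments (here, queue_ir_val) untouched.
    if hasattr(lhs_typ, "__allocate__"):
        return lhs_typ.__allocate__(size, **kws)
    raise AssertionError("dpex only reaches this copy with __allocate__ types")

print(mk_alloc_toy(QueueAwareArrayType(), 10, queue_ir_val="$q"))
# -> alloc(10, queue=$q)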

numba_dpex/core/types/dpnp_ndarray_type.py

Lines changed: 56 additions & 13 deletions
@@ -2,14 +2,28 @@
 #
 # SPDX-License-Identifier: Apache-2.0
 
+from functools import partial
+
 import dpnp
 from numba.core import ir, types
 from numba.core.ir_utils import get_np_ufunc_typ, mk_unique_var
-from numba.core.pythonapi import NativeValue, PythonAPI, box, unbox
 
 from .usm_ndarray_type import USMNdArray
 
 
+def partialclass(cls, *args, **kwds):
+    """Creates a factory subclass of the original class with preset
+    initialization arguments."""
+    cls0 = partial(cls, *args, **kwds)
+    new_cls = type(
+        cls.__name__ + "Partial",
+        (cls,),
+        {"__new__": lambda cls, *args, **kwds: cls0(*args, **kwds)},
+    )
+
+    return new_cls
+
+
 class DpnpNdArray(USMNdArray):
     """
     The Numba type to represent an dpnp.ndarray. The type has the same
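
The effect of partialclass is easiest to see on a toy class (the Greeter below is hypothetical, not part of the diff): the returned subclass acts as a factory whose instances come preconfigured. In __array_ufunc__ below, this is what lets the typing machinery instantiate the result type with the queue and usm_type of an input array already baked in.

from functools import partial

def partialclass(cls, *args, **kwds):
    # Same helper as in the diff above, repeated so this sketch runs
    # standalone.
    cls0 = partial(cls, *args, **kwds)
    return type(
        cls.__name__ + "Partial",
        (cls,),
        {"__new__": lambda cls, *args, **kwds: cls0(*args, **kwds)},
    )

class Greeter:  # hypothetical stand-in for DpnpNdArray
    def __init__(self, name, greeting="hello"):
        self.name = name
        self.greeting = greeting

LoudGreeter = partialclass(Greeter, greeting="HELLO")
g = LoudGreeter("world")
assert isinstance(g, Greeter) and g.greeting == "HELLO"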
@@ -40,15 +54,22 @@ def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
         Returns: The DpnpNdArray class.
         """
         if method == "__call__":
-            if not all(
-                (
-                    isinstance(inp, DpnpNdArray)
-                    or isinstance(inp, types.abstract.Number)
-                )
-                for inp in inputs
-            ):
+            dpnp_type = None
+
+            for inp in inputs:
+                if isinstance(inp, DpnpNdArray):
+                    dpnp_type = inp
+                    continue
+                if isinstance(inp, types.abstract.Number):
+                    continue
+
                 return NotImplemented
-            return DpnpNdArray
+
+            assert dpnp_type is not None
+
+            return partialclass(
+                DpnpNdArray, queue=dpnp_type.queue, usm_type=dpnp_type.usm_type
+            )
         else:
             return
 
@@ -71,6 +92,8 @@ def __allocate__(
         lhs_typ,
         size_typ,
         out,
+        # dpex-specific argument:
+        queue_ir_val=None,
     ):
         """Generates the Numba typed IR representing the allocation of a new
        DpnpNdArray using the dpnp.ndarray overload.
@@ -94,6 +117,10 @@ def __allocate__(
 
         Returns: The IR Value for the allocated array
         """
+        # TODO: it looks like this is only called for parfor allocations,
+        # so can we rely on that? We can grab the queue from the rhs input
+        # arguments, but the doc places no parfor-only restriction on use.
+        assert queue_ir_val is not None
         g_np_var = ir.Var(scope, mk_unique_var("$np_g_var"), loc)
         if typemap:
             typemap[g_np_var.name] = types.misc.Module(dpnp)
@@ -132,11 +159,13 @@ def __allocate__(
         usm_typ_var = ir.Var(scope, mk_unique_var("$np_usm_type_var"), loc)
         # A default device string arg added as a placeholder
         device_typ_var = ir.Var(scope, mk_unique_var("$np_device_var"), loc)
+        queue_typ_var = ir.Var(scope, mk_unique_var("$np_queue_var"), loc)
 
         if typemap:
             typemap[layout_var.name] = types.literal(lhs_typ.layout)
             typemap[usm_typ_var.name] = types.literal(lhs_typ.usm_type)
-            typemap[device_typ_var.name] = types.literal(lhs_typ.device)
+            typemap[device_typ_var.name] = types.none
+            typemap[queue_typ_var.name] = lhs_typ.queue
 
         layout_var_assign = ir.Assign(
             ir.Const(lhs_typ.layout, loc), layout_var, loc
@@ -145,16 +174,29 @@ def __allocate__(
             ir.Const(lhs_typ.usm_type, loc), usm_typ_var, loc
         )
         device_typ_var_assign = ir.Assign(
-            ir.Const(lhs_typ.device, loc), device_typ_var, loc
+            ir.Const(None, loc), device_typ_var, loc
         )
+        queue_typ_var_assign = ir.Assign(queue_ir_val, queue_typ_var, loc)
 
         out.extend(
-            [layout_var_assign, usm_typ_var_assign, device_typ_var_assign]
+            [
+                layout_var_assign,
+                usm_typ_var_assign,
+                device_typ_var_assign,
+                queue_typ_var_assign,
+            ]
         )
 
         alloc_call = ir.Expr.call(
             attr_var,
-            [size_var, typ_var, layout_var, device_typ_var, usm_typ_var],
+            [
+                size_var,
+                typ_var,
+                layout_var,
+                device_typ_var,
+                usm_typ_var,
+                queue_typ_var,
+            ],
             (),
             loc,
         )
@@ -170,6 +212,7 @@ def __allocate__(
                     layout_var,
                     device_typ_var,
                     usm_typ_var,
+                    queue_typ_var,
                 ]
             ],
             {},
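
Put together, the typed IR assembled by __allocate__ now corresponds, conceptually, to a dpnp.empty call whose trailing queue argument carries the queue extracted from the parfor's inputs. This is a sketch of the intent only, assuming the positional IR arguments map onto the dpex overload's parameters in the order listed in alloc_call:

# Conceptual equivalent of the generated IR, not literal emitted code:
result = dpnp.empty(
    shape,      # size_var
    dtype,      # typ_var
    layout,     # layout_var, e.g. "C"
    None,       # device_typ_var: device is demoted to a None placeholder
    usm_type,   # usm_typ_var, e.g. "device"
    queue,      # queue_typ_var: loaded from an input array's sycl_queue
)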
