Skip to content

Commit 7804da8

Browse files
committed
Add dpnp/dpjit specific parfor
1 parent e0639f0 commit 7804da8

File tree

2 files changed

+252
-1
lines changed

2 files changed

+252
-1
lines changed
Lines changed: 251 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,251 @@
1+
# SPDX-FileCopyrightText: 2024 Intel Corporation
2+
#
3+
# SPDX-License-Identifier: Apache-2.0
4+
5+
"""
6+
This module follows the logic of numba/parfors/parfor.py with changes required
7+
to use it with dpnp instead of numpy.
8+
"""
9+
10+
import warnings
11+
12+
from numba.core import config, errors, ir, types
13+
from numba.core.compiler_machinery import register_pass
14+
from numba.core.ir_utils import (
15+
dprint_func_ir,
16+
mk_alloc,
17+
mk_unique_var,
18+
next_label,
19+
)
20+
from numba.core.typed_passes import ParforPass as NumpyParforPass
21+
from numba.core.typed_passes import _reload_parfors
22+
from numba.parfors.parfor import (
23+
ConvertInplaceBinop,
24+
ConvertLoopPass,
25+
ConvertNumpyPass,
26+
ConvertReducePass,
27+
ConvertSetItemPass,
28+
Parfor,
29+
)
30+
from numba.parfors.parfor import ParforPass as _NumpyParforPass
31+
from numba.parfors.parfor import (
32+
_make_index_var,
33+
_mk_parfor_loops,
34+
repr_arrayexpr,
35+
signature,
36+
)
37+
from numba.stencils.stencilparfor import StencilPass
38+
39+
from numba_dpex.numba_patches.patch_arrayexpr_tree_to_ir import (
40+
_arrayexpr_tree_to_ir,
41+
)
42+
43+
44+
class ConvertDPNPPass(ConvertNumpyPass):
    """Turn supported dpnp functions and arrayexpr nodes into parfor nodes.

    Derived from ``ConvertNumpyPass``. Much of the code is copied from the
    parent because the original package offers no extension point; the only
    overridden piece of logic is ``_arrayexpr_to_parfor``, which routes the
    expression-tree lowering through the ``_arrayexpr_tree_to_ir`` imported
    from ``numba_dpex.numba_patches``.
    """

    def __init__(self, pass_states):
        super().__init__(pass_states)

    def _arrayexpr_to_parfor(self, equiv_set, lhs, arrayexpr, avail_vars):
        """Build a parfor out of an arrayexpr node (a map over a recursive
        expression tree).

        Mirrors the parent implementation, except that the body is generated
        by the patched ``_arrayexpr_tree_to_ir``.

        Args:
            equiv_set: Object queried for the shape of ``lhs`` and forwarded
                to the tree lowering and to the ``Parfor`` constructor.
            lhs: Target variable; its ``scope``, ``loc`` and ``name`` are
                used, and its type is looked up in the typemap.
            arrayexpr: Node whose ``expr`` attribute holds the expression
                tree to lower.
            avail_vars: Forwarded unchanged to ``_arrayexpr_tree_to_ir``.

        Returns:
            The newly created ``Parfor`` node.
        """
        states = self.pass_states
        scope, loc = lhs.scope, lhs.loc
        tree = arrayexpr.expr
        lhs_typ = states.typemap[lhs.name]
        elem_typ = lhs_typ.dtype

        # Loop nests and size variables come from the shape correlated
        # with the left-hand side.
        shape_vars = equiv_set.get_shape(lhs)
        index_vars, loopnests = _mk_parfor_loops(
            states.typemap, shape_vars, scope, loc
        )

        # The init block allocates the output array.
        init_block = ir.Block(scope, loc)
        init_block.body = mk_alloc(
            states.typingctx,
            states.typemap,
            states.calltypes,
            lhs,
            tuple(shape_vars),
            elem_typ,
            scope,
            loc,
            lhs_typ,
        )

        # The body block computes one element and stores it into lhs.
        body_label = next_label()
        body_block = ir.Block(scope, loc)
        result_var = ir.Var(scope, mk_unique_var("$expr_out_var"), loc)
        states.typemap[result_var.name] = elem_typ

        index_var, index_var_typ = _make_index_var(
            states.typemap, scope, index_vars, body_block
        )

        element_stmts = _arrayexpr_tree_to_ir(
            states.func_ir,
            states.typingctx,
            states.typemap,
            states.calltypes,
            equiv_set,
            init_block,
            result_var,
            tree,
            index_var,
            index_vars,
            avail_vars,
        )
        body_block.body.extend(element_stmts)

        pattern = "array expression {}".format(repr_arrayexpr(arrayexpr.expr))

        parfor = Parfor(
            loopnests,
            init_block,
            {},
            loc,
            index_var,
            equiv_set,
            pattern,
            states.flags,
        )

        store = ir.SetItem(lhs, index_var, result_var, loc)
        states.calltypes[store] = signature(
            types.none, states.typemap[lhs.name], index_var_typ, elem_typ
        )
        body_block.body.append(store)
        parfor.loop_body = {body_label: body_block}

        if config.DEBUG_ARRAY_OPT >= 1:
            print("parfor from arrayexpr")
            parfor.dump()
        return parfor
136+
137+
138+
class _ParforPass(_NumpyParforPass):
    """Parfor conversion driver that uses ``ConvertDPNPPass``.

    Derived from Numba's ``numba.parfors.parfor.ParforPass``. Much of the
    code is copied from the parent because the original package offers no
    extension point; ``run`` differs only in using ``ConvertDPNPPass`` where
    the parent uses ``ConvertNumpyPass``.
    """

    def run(self):
        """Run the parfor conversion pass over ``self.func_ir``.

        Calls ``_pre_run``, then applies the stencil pass and each conversion
        pass whose option is enabled, and finally sets up diagnostics and
        dumps the IR via ``dprint_func_ir``.
        """
        self._pre_run()

        # Stencil translation runs before the other conversions.
        if self.options.stencil:
            StencilPass(
                self.func_ir,
                self.typemap,
                self.calltypes,
                self.array_analysis,
                self.typingctx,
                self.targetctx,
                self.flags,
            ).run()

        # (option attribute, converter class) in execution order. Options
        # are read lazily, right before the corresponding converter runs.
        converters = (
            ("setitem", ConvertSetItemPass),
            ("numpy", ConvertDPNPPass),
            ("reduction", ConvertReducePass),
            ("prange", ConvertLoopPass),
            ("inplace_binop", ConvertInplaceBinop),
        )
        for option_name, converter_cls in converters:
            if getattr(self.options, option_name):
                converter_cls(self).run(self.func_ir.blocks)

        # Diagnostics are set up once the parfors have been found.
        self.diagnostics.setup(self.func_ir, self.options.fusion)

        dprint_func_ir(self.func_ir, "after parfor pass")
183+
184+
185+
@register_pass(mutates_CFG=True, analysis_only=False)
class ParforPass(NumpyParforPass):
    """Registered compiler pass that runs the dpnp-aware parfor conversion.

    Derived from Numba's typed-pass ``ParforPass``. Much of the code is
    copied from the parent because the original package offers no extension
    point; ``run_pass`` differs only in delegating to ``_ParforPass``.
    """

    _name = "dpnp_parfor_pass"

    def __init__(self):
        NumpyParforPass.__init__(self)

    def run_pass(self, state):
        """Convert data-parallel computations into Parfor nodes.

        Builds a ``_ParforPass`` from the fields of ``state`` and runs it.
        If no ``Parfor`` statement is present afterwards, a
        ``NumbaPerformanceWarning`` is emitted unless performance warnings
        are disabled or the function's location filename is ``"<string>"``.
        Finally ``_reload_parfors`` is appended to ``state.reload_init``.

        Args:
            state: Compiler state providing ``func_ir``, ``typemap``,
                ``calltypes``, ``return_type``, ``typingctx``, ``targetctx``,
                ``flags``, ``metadata``, ``parfor_diagnostics`` and
                ``reload_init``.

        Returns:
            Always ``True``.
        """
        # An IR (and with it type information) must already be available.
        assert state.func_ir

        _ParforPass(
            state.func_ir,
            state.typemap,
            state.calltypes,
            state.return_type,
            state.typingctx,
            state.targetctx,
            state.flags.auto_parallel,
            state.flags,
            state.metadata,
            state.parfor_diagnostics,
        ).run()

        # Did the conversion produce at least one parfor? Stops at the
        # first hit.
        found_parfor = any(
            isinstance(stmt, Parfor)
            for block in state.func_ir.blocks.values()
            for stmt in block.body
        )

        # Parfor calls the compiler chain again with a string, hence the
        # "<string>" filename exemption.
        if not found_parfor and not (
            config.DISABLE_PERFORMANCE_WARNINGS
            or state.func_ir.loc.filename == "<string>"
        ):
            help_url = (
                "https://numba.readthedocs.io/en/stable/user/"
                "parallel.html#diagnostics"
            )
            text = (
                "\nThe keyword argument 'parallel=True' was specified "
                "but no transformation for parallel execution was "
                "possible.\n\nTo find out why, try turning on parallel "
                "diagnostics, see %s for help." % help_url
            )
            warnings.warn(
                errors.NumbaPerformanceWarning(text, state.func_ir.loc)
            )

        # Register the reload function that initializes the parallel backend.
        state.reload_init.append(_reload_parfors)
        return True

numba_dpex/core/pipelines/dpjit_compiler.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,13 +13,13 @@
1313
NoPythonSupportedFeatureValidation,
1414
NopythonTypeInference,
1515
ParforFusionPass,
16-
ParforPass,
1716
ParforPreLoweringPass,
1817
PreLowerStripPhis,
1918
PreParforPass,
2019
)
2120

2221
from numba_dpex.core.exceptions import UnsupportedCompilationModeError
22+
from numba_dpex.core.parfors.parfor_pass import ParforPass
2323
from numba_dpex.core.passes import (
2424
DumpParforDiagnostics,
2525
NoPythonBackend,

0 commit comments

Comments
 (0)