Skip to content

Commit 2dce65a

Browse files
committed
Move numba/np parfor patch to dpnp parfor
1 parent 0228db3 commit 2dce65a

File tree

2 files changed

+188
-199
lines changed

2 files changed

+188
-199
lines changed

numba_dpex/core/parfors/parfor_pass.py

Lines changed: 188 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,13 @@
77
to use it with dpnp instead of numpy.
88
"""
99

10+
11+
import copy
12+
import math
13+
import operator
1014
import warnings
1115

12-
from numba.core import config, errors, ir, types
16+
from numba.core import config, errors, ir, types, typing
1317
from numba.core.compiler_machinery import register_pass
1418
from numba.core.ir_utils import (
1519
dprint_func_ir,
@@ -19,6 +23,8 @@
1923
)
2024
from numba.core.typed_passes import ParforPass as NumpyParforPass
2125
from numba.core.typed_passes import _reload_parfors
26+
from numba.core.typing import npydecl
27+
from numba.parfors import array_analysis, parfor
2228
from numba.parfors.parfor import (
2329
ConvertInplaceBinop,
2430
ConvertLoopPass,
@@ -36,10 +42,6 @@
3642
)
3743
from numba.stencils.stencilparfor import StencilPass
3844

39-
from numba_dpex.numba_patches.patch_arrayexpr_tree_to_ir import (
40-
_arrayexpr_tree_to_ir,
41-
)
42-
4345

4446
class ConvertDPNPPass(ConvertNumpyPass):
4547
"""
@@ -249,3 +251,184 @@ def run_pass(self, state):
249251
# Add reload function to initialize the parallel backend.
250252
state.reload_init.append(_reload_parfors)
251253
return True
254+
255+
256+
def _ufunc_to_parfor_instr(
257+
typemap,
258+
op,
259+
avail_vars,
260+
loc,
261+
scope,
262+
func_ir,
263+
out_ir,
264+
arg_vars,
265+
typingctx,
266+
calltypes,
267+
expr_out_var,
268+
):
269+
func_var_name = parfor._find_func_var(typemap, op, avail_vars, loc=loc)
270+
func_var = ir.Var(scope, mk_unique_var(func_var_name), loc)
271+
typemap[func_var.name] = typemap[func_var_name]
272+
func_var_def = copy.deepcopy(func_ir.get_definition(func_var_name))
273+
if (
274+
isinstance(func_var_def, ir.Expr)
275+
and func_var_def.op == "getattr"
276+
and func_var_def.attr == "sqrt"
277+
):
278+
g_math_var = ir.Var(scope, mk_unique_var("$math_g_var"), loc)
279+
typemap[g_math_var.name] = types.misc.Module(math)
280+
g_math = ir.Global("math", math, loc)
281+
g_math_assign = ir.Assign(g_math, g_math_var, loc)
282+
func_var_def = ir.Expr.getattr(g_math_var, "sqrt", loc)
283+
out_ir.append(g_math_assign)
284+
ir_expr = ir.Expr.call(func_var, arg_vars, (), loc)
285+
call_typ = typemap[func_var.name].get_call_type(
286+
typingctx, tuple(typemap[a.name] for a in arg_vars), {}
287+
)
288+
calltypes[ir_expr] = call_typ
289+
el_typ = call_typ.return_type
290+
out_ir.append(ir.Assign(func_var_def, func_var, loc))
291+
out_ir.append(ir.Assign(ir_expr, expr_out_var, loc))
292+
293+
return el_typ
294+
295+
296+
def _arrayexpr_tree_to_ir(
297+
func_ir,
298+
typingctx,
299+
typemap,
300+
calltypes,
301+
equiv_set,
302+
init_block,
303+
expr_out_var,
304+
expr,
305+
parfor_index_tuple_var,
306+
all_parfor_indices,
307+
avail_vars,
308+
):
309+
"""generate IR from array_expr's expr tree recursively. Assign output to
310+
expr_out_var and returns the whole IR as a list of Assign nodes.
311+
"""
312+
el_typ = typemap[expr_out_var.name]
313+
scope = expr_out_var.scope
314+
loc = expr_out_var.loc
315+
out_ir = []
316+
317+
if isinstance(expr, tuple):
318+
op, arr_expr_args = expr
319+
arg_vars = []
320+
for arg in arr_expr_args:
321+
arg_out_var = ir.Var(scope, mk_unique_var("$arg_out_var"), loc)
322+
typemap[arg_out_var.name] = el_typ
323+
out_ir += _arrayexpr_tree_to_ir(
324+
func_ir,
325+
typingctx,
326+
typemap,
327+
calltypes,
328+
equiv_set,
329+
init_block,
330+
arg_out_var,
331+
arg,
332+
parfor_index_tuple_var,
333+
all_parfor_indices,
334+
avail_vars,
335+
)
336+
arg_vars.append(arg_out_var)
337+
if op in npydecl.supported_array_operators:
338+
el_typ1 = typemap[arg_vars[0].name]
339+
if len(arg_vars) == 2:
340+
el_typ2 = typemap[arg_vars[1].name]
341+
func_typ = typingctx.resolve_function_type(
342+
op, (el_typ1, el_typ2), {}
343+
)
344+
ir_expr = ir.Expr.binop(op, arg_vars[0], arg_vars[1], loc)
345+
if op == operator.truediv:
346+
func_typ, ir_expr = parfor._gen_np_divide(
347+
arg_vars[0], arg_vars[1], out_ir, typemap
348+
)
349+
else:
350+
func_typ = typingctx.resolve_function_type(op, (el_typ1,), {})
351+
ir_expr = ir.Expr.unary(op, arg_vars[0], loc)
352+
calltypes[ir_expr] = func_typ
353+
el_typ = func_typ.return_type
354+
out_ir.append(ir.Assign(ir_expr, expr_out_var, loc))
355+
for T in array_analysis.MAP_TYPES:
356+
if isinstance(op, T):
357+
# function calls are stored in variables which are not removed
358+
# op is typing_key to the variables type
359+
func_var_name = parfor._find_func_var(
360+
typemap, op, avail_vars, loc=loc
361+
)
362+
func_var = ir.Var(scope, mk_unique_var(func_var_name), loc)
363+
typemap[func_var.name] = typemap[func_var_name]
364+
func_var_def = copy.deepcopy(
365+
func_ir.get_definition(func_var_name)
366+
)
367+
if (
368+
isinstance(func_var_def, ir.Expr)
369+
and func_var_def.op == "getattr"
370+
and func_var_def.attr == "sqrt"
371+
):
372+
g_math_var = ir.Var(
373+
scope, mk_unique_var("$math_g_var"), loc
374+
)
375+
typemap[g_math_var.name] = types.misc.Module(math)
376+
g_math = ir.Global("math", math, loc)
377+
g_math_assign = ir.Assign(g_math, g_math_var, loc)
378+
func_var_def = ir.Expr.getattr(g_math_var, "sqrt", loc)
379+
out_ir.append(g_math_assign)
380+
ir_expr = ir.Expr.call(func_var, arg_vars, (), loc)
381+
call_typ = typemap[func_var.name].get_call_type(
382+
typingctx, tuple(typemap[a.name] for a in arg_vars), {}
383+
)
384+
calltypes[ir_expr] = call_typ
385+
el_typ = call_typ.return_type
386+
out_ir.append(ir.Assign(func_var_def, func_var, loc))
387+
out_ir.append(ir.Assign(ir_expr, expr_out_var, loc))
388+
# NUMBA_DPEX: is_dpnp_func check was added
389+
if hasattr(op, "is_dpnp_ufunc"):
390+
el_typ = _ufunc_to_parfor_instr(
391+
typemap,
392+
op,
393+
avail_vars,
394+
loc,
395+
scope,
396+
func_ir,
397+
out_ir,
398+
arg_vars,
399+
typingctx,
400+
calltypes,
401+
expr_out_var,
402+
)
403+
elif isinstance(expr, ir.Var):
404+
var_typ = typemap[expr.name]
405+
if isinstance(var_typ, types.Array):
406+
el_typ = var_typ.dtype
407+
ir_expr = parfor._gen_arrayexpr_getitem(
408+
equiv_set,
409+
expr,
410+
parfor_index_tuple_var,
411+
all_parfor_indices,
412+
el_typ,
413+
calltypes,
414+
typingctx,
415+
typemap,
416+
init_block,
417+
out_ir,
418+
)
419+
else:
420+
el_typ = var_typ
421+
ir_expr = expr
422+
out_ir.append(ir.Assign(ir_expr, expr_out_var, loc))
423+
elif isinstance(expr, ir.Const):
424+
el_typ = typing.Context().resolve_value_type(expr.value)
425+
out_ir.append(ir.Assign(expr, expr_out_var, loc))
426+
427+
if len(out_ir) == 0:
428+
raise errors.UnsupportedRewriteError(
429+
f"Don't know how to translate array expression '{expr:r}'",
430+
loc=expr.loc,
431+
)
432+
typemap.pop(expr_out_var.name, None)
433+
typemap[expr_out_var.name] = el_typ
434+
return out_ir

0 commit comments

Comments
 (0)