Skip to content

Commit e0639f0

Browse files
committed
Upgrade patch_arrayexpr_tree_to_ir
1 parent 2ea273e commit e0639f0

File tree

1 file changed

+29
-12
lines changed

1 file changed

+29
-12
lines changed

numba_dpex/numba_patches/patch_arrayexpr_tree_to_ir.py

Lines changed: 29 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -116,19 +116,36 @@ def _arrayexpr_tree_to_ir(
116116
if isinstance(op, T):
117117
# function calls are stored in variables which are not removed
118118
# op is typing_key to the variables type
119-
el_typ = _ufunc_to_parfor_instr(
120-
typemap,
121-
op,
122-
avail_vars,
123-
loc,
124-
scope,
125-
func_ir,
126-
out_ir,
127-
arg_vars,
128-
typingctx,
129-
calltypes,
130-
expr_out_var,
119+
func_var_name = parfor._find_func_var(
120+
typemap, op, avail_vars, loc=loc
131121
)
122+
func_var = ir.Var(scope, mk_unique_var(func_var_name), loc)
123+
typemap[func_var.name] = typemap[func_var_name]
124+
func_var_def = copy.deepcopy(
125+
func_ir.get_definition(func_var_name)
126+
)
127+
if (
128+
isinstance(func_var_def, ir.Expr)
129+
and func_var_def.op == "getattr"
130+
and func_var_def.attr == "sqrt"
131+
):
132+
g_math_var = ir.Var(
133+
scope, mk_unique_var("$math_g_var"), loc
134+
)
135+
typemap[g_math_var.name] = types.misc.Module(math)
136+
g_math = ir.Global("math", math, loc)
137+
g_math_assign = ir.Assign(g_math, g_math_var, loc)
138+
func_var_def = ir.Expr.getattr(g_math_var, "sqrt", loc)
139+
out_ir.append(g_math_assign)
140+
ir_expr = ir.Expr.call(func_var, arg_vars, (), loc)
141+
call_typ = typemap[func_var.name].get_call_type(
142+
typingctx, tuple(typemap[a.name] for a in arg_vars), {}
143+
)
144+
calltypes[ir_expr] = call_typ
145+
el_typ = call_typ.return_type
146+
out_ir.append(ir.Assign(func_var_def, func_var, loc))
147+
out_ir.append(ir.Assign(ir_expr, expr_out_var, loc))
148+
# NUMBA_DPEX: is_dpnp_func check was added
132149
if hasattr(op, "is_dpnp_ufunc"):
133150
el_typ = _ufunc_to_parfor_instr(
134151
typemap,

0 commit comments

Comments
 (0)