Skip to content

Commit 2ea273e

Browse files
committed
Flatten patch_arrayexpr_tree_to_ir
1 parent 778bde3 commit 2ea273e

File tree

1 file changed

+156
-155
lines changed

1 file changed

+156
-155
lines changed

numba_dpex/numba_patches/patch_arrayexpr_tree_to_ir.py

Lines changed: 156 additions & 155 deletions
Original file line numberDiff line numberDiff line change
@@ -3,139 +3,119 @@
33
# SPDX-License-Identifier: Apache-2.0
44

55

6-
def patch():
7-
"""
8-
Patches the _arrayexpr_tree_to_ir function in numba.parfor.parfor.py to
9-
support array expression nodes that were generated from dpnp expressions.
10-
"""
6+
import copy
7+
import math
8+
import operator
119

12-
import copy
13-
import math
14-
import operator
10+
from numba.core import errors, ir, types, typing
11+
from numba.core.ir_utils import mk_unique_var
12+
from numba.core.typing import npydecl
13+
from numba.parfors import array_analysis, parfor
1514

16-
from numba.core import errors, ir, types, typing
17-
from numba.core.ir_utils import mk_unique_var
18-
from numba.core.typing import npydecl
19-
from numba.parfors import array_analysis, parfor
2015

21-
def _ufunc_to_parfor_instr(
22-
typemap,
23-
op,
24-
avail_vars,
25-
loc,
26-
scope,
27-
func_ir,
28-
out_ir,
29-
arg_vars,
30-
typingctx,
31-
calltypes,
32-
expr_out_var,
16+
def _ufunc_to_parfor_instr(
17+
typemap,
18+
op,
19+
avail_vars,
20+
loc,
21+
scope,
22+
func_ir,
23+
out_ir,
24+
arg_vars,
25+
typingctx,
26+
calltypes,
27+
expr_out_var,
28+
):
29+
func_var_name = parfor._find_func_var(typemap, op, avail_vars, loc=loc)
30+
func_var = ir.Var(scope, mk_unique_var(func_var_name), loc)
31+
typemap[func_var.name] = typemap[func_var_name]
32+
func_var_def = copy.deepcopy(func_ir.get_definition(func_var_name))
33+
if (
34+
isinstance(func_var_def, ir.Expr)
35+
and func_var_def.op == "getattr"
36+
and func_var_def.attr == "sqrt"
3337
):
34-
func_var_name = parfor._find_func_var(typemap, op, avail_vars, loc=loc)
35-
func_var = ir.Var(scope, mk_unique_var(func_var_name), loc)
36-
typemap[func_var.name] = typemap[func_var_name]
37-
func_var_def = copy.deepcopy(func_ir.get_definition(func_var_name))
38-
if (
39-
isinstance(func_var_def, ir.Expr)
40-
and func_var_def.op == "getattr"
41-
and func_var_def.attr == "sqrt"
42-
):
43-
g_math_var = ir.Var(scope, mk_unique_var("$math_g_var"), loc)
44-
typemap[g_math_var.name] = types.misc.Module(math)
45-
g_math = ir.Global("math", math, loc)
46-
g_math_assign = ir.Assign(g_math, g_math_var, loc)
47-
func_var_def = ir.Expr.getattr(g_math_var, "sqrt", loc)
48-
out_ir.append(g_math_assign)
49-
ir_expr = ir.Expr.call(func_var, arg_vars, (), loc)
50-
call_typ = typemap[func_var.name].get_call_type(
51-
typingctx, tuple(typemap[a.name] for a in arg_vars), {}
52-
)
53-
calltypes[ir_expr] = call_typ
54-
el_typ = call_typ.return_type
55-
out_ir.append(ir.Assign(func_var_def, func_var, loc))
56-
out_ir.append(ir.Assign(ir_expr, expr_out_var, loc))
38+
g_math_var = ir.Var(scope, mk_unique_var("$math_g_var"), loc)
39+
typemap[g_math_var.name] = types.misc.Module(math)
40+
g_math = ir.Global("math", math, loc)
41+
g_math_assign = ir.Assign(g_math, g_math_var, loc)
42+
func_var_def = ir.Expr.getattr(g_math_var, "sqrt", loc)
43+
out_ir.append(g_math_assign)
44+
ir_expr = ir.Expr.call(func_var, arg_vars, (), loc)
45+
call_typ = typemap[func_var.name].get_call_type(
46+
typingctx, tuple(typemap[a.name] for a in arg_vars), {}
47+
)
48+
calltypes[ir_expr] = call_typ
49+
el_typ = call_typ.return_type
50+
out_ir.append(ir.Assign(func_var_def, func_var, loc))
51+
out_ir.append(ir.Assign(ir_expr, expr_out_var, loc))
5752

58-
return el_typ
53+
return el_typ
5954

60-
def _arrayexpr_tree_to_ir(
61-
func_ir,
62-
typingctx,
63-
typemap,
64-
calltypes,
65-
equiv_set,
66-
init_block,
67-
expr_out_var,
68-
expr,
69-
parfor_index_tuple_var,
70-
all_parfor_indices,
71-
avail_vars,
72-
):
73-
"""generate IR from array_expr's expr tree recursively. Assign output to
74-
expr_out_var and returns the whole IR as a list of Assign nodes.
75-
"""
76-
el_typ = typemap[expr_out_var.name]
77-
scope = expr_out_var.scope
78-
loc = expr_out_var.loc
79-
out_ir = []
8055

81-
if isinstance(expr, tuple):
82-
op, arr_expr_args = expr
83-
arg_vars = []
84-
for arg in arr_expr_args:
85-
arg_out_var = ir.Var(scope, mk_unique_var("$arg_out_var"), loc)
86-
typemap[arg_out_var.name] = el_typ
87-
out_ir += _arrayexpr_tree_to_ir(
88-
func_ir,
89-
typingctx,
90-
typemap,
91-
calltypes,
92-
equiv_set,
93-
init_block,
94-
arg_out_var,
95-
arg,
96-
parfor_index_tuple_var,
97-
all_parfor_indices,
98-
avail_vars,
56+
def _arrayexpr_tree_to_ir(
57+
func_ir,
58+
typingctx,
59+
typemap,
60+
calltypes,
61+
equiv_set,
62+
init_block,
63+
expr_out_var,
64+
expr,
65+
parfor_index_tuple_var,
66+
all_parfor_indices,
67+
avail_vars,
68+
):
69+
"""generate IR from array_expr's expr tree recursively. Assign output to
70+
expr_out_var and returns the whole IR as a list of Assign nodes.
71+
"""
72+
el_typ = typemap[expr_out_var.name]
73+
scope = expr_out_var.scope
74+
loc = expr_out_var.loc
75+
out_ir = []
76+
77+
if isinstance(expr, tuple):
78+
op, arr_expr_args = expr
79+
arg_vars = []
80+
for arg in arr_expr_args:
81+
arg_out_var = ir.Var(scope, mk_unique_var("$arg_out_var"), loc)
82+
typemap[arg_out_var.name] = el_typ
83+
out_ir += _arrayexpr_tree_to_ir(
84+
func_ir,
85+
typingctx,
86+
typemap,
87+
calltypes,
88+
equiv_set,
89+
init_block,
90+
arg_out_var,
91+
arg,
92+
parfor_index_tuple_var,
93+
all_parfor_indices,
94+
avail_vars,
95+
)
96+
arg_vars.append(arg_out_var)
97+
if op in npydecl.supported_array_operators:
98+
el_typ1 = typemap[arg_vars[0].name]
99+
if len(arg_vars) == 2:
100+
el_typ2 = typemap[arg_vars[1].name]
101+
func_typ = typingctx.resolve_function_type(
102+
op, (el_typ1, el_typ2), {}
99103
)
100-
arg_vars.append(arg_out_var)
101-
if op in npydecl.supported_array_operators:
102-
el_typ1 = typemap[arg_vars[0].name]
103-
if len(arg_vars) == 2:
104-
el_typ2 = typemap[arg_vars[1].name]
105-
func_typ = typingctx.resolve_function_type(
106-
op, (el_typ1, el_typ2), {}
107-
)
108-
ir_expr = ir.Expr.binop(op, arg_vars[0], arg_vars[1], loc)
109-
if op == operator.truediv:
110-
func_typ, ir_expr = parfor._gen_np_divide(
111-
arg_vars[0], arg_vars[1], out_ir, typemap
112-
)
113-
else:
114-
func_typ = typingctx.resolve_function_type(
115-
op, (el_typ1,), {}
116-
)
117-
ir_expr = ir.Expr.unary(op, arg_vars[0], loc)
118-
calltypes[ir_expr] = func_typ
119-
el_typ = func_typ.return_type
120-
out_ir.append(ir.Assign(ir_expr, expr_out_var, loc))
121-
for T in array_analysis.MAP_TYPES:
122-
if isinstance(op, T):
123-
# function calls are stored in variables which are not removed
124-
# op is typing_key to the variables type
125-
el_typ = _ufunc_to_parfor_instr(
126-
typemap,
127-
op,
128-
avail_vars,
129-
loc,
130-
scope,
131-
func_ir,
132-
out_ir,
133-
arg_vars,
134-
typingctx,
135-
calltypes,
136-
expr_out_var,
104+
ir_expr = ir.Expr.binop(op, arg_vars[0], arg_vars[1], loc)
105+
if op == operator.truediv:
106+
func_typ, ir_expr = parfor._gen_np_divide(
107+
arg_vars[0], arg_vars[1], out_ir, typemap
137108
)
138-
if hasattr(op, "is_dpnp_ufunc"):
109+
else:
110+
func_typ = typingctx.resolve_function_type(op, (el_typ1,), {})
111+
ir_expr = ir.Expr.unary(op, arg_vars[0], loc)
112+
calltypes[ir_expr] = func_typ
113+
el_typ = func_typ.return_type
114+
out_ir.append(ir.Assign(ir_expr, expr_out_var, loc))
115+
for T in array_analysis.MAP_TYPES:
116+
if isinstance(op, T):
117+
# function calls are stored in variables which are not removed
118+
# op is typing_key to the variables type
139119
el_typ = _ufunc_to_parfor_instr(
140120
typemap,
141121
op,
@@ -149,37 +129,58 @@ def _arrayexpr_tree_to_ir(
149129
calltypes,
150130
expr_out_var,
151131
)
152-
elif isinstance(expr, ir.Var):
153-
var_typ = typemap[expr.name]
154-
if isinstance(var_typ, types.Array):
155-
el_typ = var_typ.dtype
156-
ir_expr = parfor._gen_arrayexpr_getitem(
157-
equiv_set,
158-
expr,
159-
parfor_index_tuple_var,
160-
all_parfor_indices,
161-
el_typ,
162-
calltypes,
163-
typingctx,
164-
typemap,
165-
init_block,
166-
out_ir,
167-
)
168-
else:
169-
el_typ = var_typ
170-
ir_expr = expr
171-
out_ir.append(ir.Assign(ir_expr, expr_out_var, loc))
172-
elif isinstance(expr, ir.Const):
173-
el_typ = typing.Context().resolve_value_type(expr.value)
174-
out_ir.append(ir.Assign(expr, expr_out_var, loc))
175-
176-
if len(out_ir) == 0:
177-
raise errors.UnsupportedRewriteError(
178-
f"Don't know how to translate array expression '{expr:r}'",
179-
loc=expr.loc,
132+
if hasattr(op, "is_dpnp_ufunc"):
133+
el_typ = _ufunc_to_parfor_instr(
134+
typemap,
135+
op,
136+
avail_vars,
137+
loc,
138+
scope,
139+
func_ir,
140+
out_ir,
141+
arg_vars,
142+
typingctx,
143+
calltypes,
144+
expr_out_var,
180145
)
181-
typemap.pop(expr_out_var.name, None)
182-
typemap[expr_out_var.name] = el_typ
183-
return out_ir
146+
elif isinstance(expr, ir.Var):
147+
var_typ = typemap[expr.name]
148+
if isinstance(var_typ, types.Array):
149+
el_typ = var_typ.dtype
150+
ir_expr = parfor._gen_arrayexpr_getitem(
151+
equiv_set,
152+
expr,
153+
parfor_index_tuple_var,
154+
all_parfor_indices,
155+
el_typ,
156+
calltypes,
157+
typingctx,
158+
typemap,
159+
init_block,
160+
out_ir,
161+
)
162+
else:
163+
el_typ = var_typ
164+
ir_expr = expr
165+
out_ir.append(ir.Assign(ir_expr, expr_out_var, loc))
166+
elif isinstance(expr, ir.Const):
167+
el_typ = typing.Context().resolve_value_type(expr.value)
168+
out_ir.append(ir.Assign(expr, expr_out_var, loc))
169+
170+
if len(out_ir) == 0:
171+
raise errors.UnsupportedRewriteError(
172+
f"Don't know how to translate array expression '{expr:r}'",
173+
loc=expr.loc,
174+
)
175+
typemap.pop(expr_out_var.name, None)
176+
typemap[expr_out_var.name] = el_typ
177+
return out_ir
178+
179+
180+
def patch():
181+
"""
182+
Patches the _arrayexpr_tree_to_ir function in numba.parfor.parfor.py to
183+
support array expression nodes that were generated from dpnp expressions.
184+
"""
184185

185186
parfor._arrayexpr_tree_to_ir = _arrayexpr_tree_to_ir

0 commit comments

Comments
 (0)