7 | 7 | to use it with dpnp instead of numpy.
8 | 8 | """
9 | 9 |
| 10 | +
| 11 | +import copy
| 12 | +import math
| 13 | +import operator
10 | 14 | import warnings
11 | 15 |
12 | | -from numba.core import config, errors, ir, types
| 16 | +from numba.core import config, errors, ir, types, typing
13 | 17 | from numba.core.compiler_machinery import register_pass
14 | 18 | from numba.core.ir_utils import (
15 | 19 |     dprint_func_ir,
19 | 23 | )
20 | 24 | from numba.core.typed_passes import ParforPass as NumpyParforPass
21 | 25 | from numba.core.typed_passes import _reload_parfors
| 26 | +from numba.core.typing import npydecl
| 27 | +from numba.parfors import array_analysis, parfor
22 | 28 | from numba.parfors.parfor import (
23 | 29 |     ConvertInplaceBinop,
24 | 30 |     ConvertLoopPass,
36 | 42 | )
37 | 43 | from numba.stencils.stencilparfor import StencilPass
38 | 44 |
39 | | -from numba_dpex.numba_patches.patch_arrayexpr_tree_to_ir import (
40 | | -    _arrayexpr_tree_to_ir,
41 | | -)
42 | | -
43 | 45 |
44 | 46 | class ConvertDPNPPass(ConvertNumpyPass):
45 | 47 |     """
@@ -249,3 +251,184 @@ def run_pass(self, state):
249 | 251 |         # Add reload function to initialize the parallel backend.
250 | 252 |         state.reload_init.append(_reload_parfors)
251 | 253 |         return True
| 254 | +
| 255 | +
| 256 | +def _ufunc_to_parfor_instr(
| 257 | +    typemap,
| 258 | +    op,
| 259 | +    avail_vars,
| 260 | +    loc,
| 261 | +    scope,
| 262 | +    func_ir,
| 263 | +    out_ir,
| 264 | +    arg_vars,
| 265 | +    typingctx,
| 266 | +    calltypes,
| 267 | +    expr_out_var,
| 268 | +):
| 269 | +    func_var_name = parfor._find_func_var(typemap, op, avail_vars, loc=loc)
| 270 | +    func_var = ir.Var(scope, mk_unique_var(func_var_name), loc)
| 271 | +    typemap[func_var.name] = typemap[func_var_name]
| 272 | +    func_var_def = copy.deepcopy(func_ir.get_definition(func_var_name))
| 273 | +    if (
| 274 | +        isinstance(func_var_def, ir.Expr)
| 275 | +        and func_var_def.op == "getattr"
| 276 | +        and func_var_def.attr == "sqrt"
| 277 | +    ):
| 278 | +        g_math_var = ir.Var(scope, mk_unique_var("$math_g_var"), loc)
| 279 | +        typemap[g_math_var.name] = types.misc.Module(math)
| 280 | +        g_math = ir.Global("math", math, loc)
| 281 | +        g_math_assign = ir.Assign(g_math, g_math_var, loc)
| 282 | +        func_var_def = ir.Expr.getattr(g_math_var, "sqrt", loc)
| 283 | +        out_ir.append(g_math_assign)
| 284 | +    ir_expr = ir.Expr.call(func_var, arg_vars, (), loc)
| 285 | +    call_typ = typemap[func_var.name].get_call_type(
| 286 | +        typingctx, tuple(typemap[a.name] for a in arg_vars), {}
| 287 | +    )
| 288 | +    calltypes[ir_expr] = call_typ
| 289 | +    el_typ = call_typ.return_type
| 290 | +    out_ir.append(ir.Assign(func_var_def, func_var, loc))
| 291 | +    out_ir.append(ir.Assign(ir_expr, expr_out_var, loc))
| 292 | +
| 293 | +    return el_typ
| 294 | +
| 295 | +
| 296 | +def _arrayexpr_tree_to_ir(
| 297 | +    func_ir,
| 298 | +    typingctx,
| 299 | +    typemap,
| 300 | +    calltypes,
| 301 | +    equiv_set,
| 302 | +    init_block,
| 303 | +    expr_out_var,
| 304 | +    expr,
| 305 | +    parfor_index_tuple_var,
| 306 | +    all_parfor_indices,
| 307 | +    avail_vars,
| 308 | +):
| 309 | +    """generate IR from array_expr's expr tree recursively. Assign output to
| 310 | +    expr_out_var and returns the whole IR as a list of Assign nodes.
| 311 | +    """
| 312 | +    el_typ = typemap[expr_out_var.name]
| 313 | +    scope = expr_out_var.scope
| 314 | +    loc = expr_out_var.loc
| 315 | +    out_ir = []
| 316 | +
| 317 | +    if isinstance(expr, tuple):
| 318 | +        op, arr_expr_args = expr
| 319 | +        arg_vars = []
| 320 | +        for arg in arr_expr_args:
| 321 | +            arg_out_var = ir.Var(scope, mk_unique_var("$arg_out_var"), loc)
| 322 | +            typemap[arg_out_var.name] = el_typ
| 323 | +            out_ir += _arrayexpr_tree_to_ir(
| 324 | +                func_ir,
| 325 | +                typingctx,
| 326 | +                typemap,
| 327 | +                calltypes,
| 328 | +                equiv_set,
| 329 | +                init_block,
| 330 | +                arg_out_var,
| 331 | +                arg,
| 332 | +                parfor_index_tuple_var,
| 333 | +                all_parfor_indices,
| 334 | +                avail_vars,
| 335 | +            )
| 336 | +            arg_vars.append(arg_out_var)
| 337 | +        if op in npydecl.supported_array_operators:
| 338 | +            el_typ1 = typemap[arg_vars[0].name]
| 339 | +            if len(arg_vars) == 2:
| 340 | +                el_typ2 = typemap[arg_vars[1].name]
| 341 | +                func_typ = typingctx.resolve_function_type(
| 342 | +                    op, (el_typ1, el_typ2), {}
| 343 | +                )
| 344 | +                ir_expr = ir.Expr.binop(op, arg_vars[0], arg_vars[1], loc)
| 345 | +                if op == operator.truediv:
| 346 | +                    func_typ, ir_expr = parfor._gen_np_divide(
| 347 | +                        arg_vars[0], arg_vars[1], out_ir, typemap
| 348 | +                    )
| 349 | +            else:
| 350 | +                func_typ = typingctx.resolve_function_type(op, (el_typ1,), {})
| 351 | +                ir_expr = ir.Expr.unary(op, arg_vars[0], loc)
| 352 | +            calltypes[ir_expr] = func_typ
| 353 | +            el_typ = func_typ.return_type
| 354 | +            out_ir.append(ir.Assign(ir_expr, expr_out_var, loc))
| 355 | +        for T in array_analysis.MAP_TYPES:
| 356 | +            if isinstance(op, T):
| 357 | +                # function calls are stored in variables which are not removed
| 358 | +                # op is typing_key to the variables type
| 359 | +                func_var_name = parfor._find_func_var(
| 360 | +                    typemap, op, avail_vars, loc=loc
| 361 | +                )
| 362 | +                func_var = ir.Var(scope, mk_unique_var(func_var_name), loc)
| 363 | +                typemap[func_var.name] = typemap[func_var_name]
| 364 | +                func_var_def = copy.deepcopy(
| 365 | +                    func_ir.get_definition(func_var_name)
| 366 | +                )
| 367 | +                if (
| 368 | +                    isinstance(func_var_def, ir.Expr)
| 369 | +                    and func_var_def.op == "getattr"
| 370 | +                    and func_var_def.attr == "sqrt"
| 371 | +                ):
| 372 | +                    g_math_var = ir.Var(
| 373 | +                        scope, mk_unique_var("$math_g_var"), loc
| 374 | +                    )
| 375 | +                    typemap[g_math_var.name] = types.misc.Module(math)
| 376 | +                    g_math = ir.Global("math", math, loc)
| 377 | +                    g_math_assign = ir.Assign(g_math, g_math_var, loc)
| 378 | +                    func_var_def = ir.Expr.getattr(g_math_var, "sqrt", loc)
| 379 | +                    out_ir.append(g_math_assign)
| 380 | +                ir_expr = ir.Expr.call(func_var, arg_vars, (), loc)
| 381 | +                call_typ = typemap[func_var.name].get_call_type(
| 382 | +                    typingctx, tuple(typemap[a.name] for a in arg_vars), {}
| 383 | +                )
| 384 | +                calltypes[ir_expr] = call_typ
| 385 | +                el_typ = call_typ.return_type
| 386 | +                out_ir.append(ir.Assign(func_var_def, func_var, loc))
| 387 | +                out_ir.append(ir.Assign(ir_expr, expr_out_var, loc))
| 388 | +        # NUMBA_DPEX: is_dpnp_ufunc check was added
| 389 | +        if hasattr(op, "is_dpnp_ufunc"):
| 390 | +            el_typ = _ufunc_to_parfor_instr(
| 391 | +                typemap,
| 392 | +                op,
| 393 | +                avail_vars,
| 394 | +                loc,
| 395 | +                scope,
| 396 | +                func_ir,
| 397 | +                out_ir,
| 398 | +                arg_vars,
| 399 | +                typingctx,
| 400 | +                calltypes,
| 401 | +                expr_out_var,
| 402 | +            )
| 403 | +    elif isinstance(expr, ir.Var):
| 404 | +        var_typ = typemap[expr.name]
| 405 | +        if isinstance(var_typ, types.Array):
| 406 | +            el_typ = var_typ.dtype
| 407 | +            ir_expr = parfor._gen_arrayexpr_getitem(
| 408 | +                equiv_set,
| 409 | +                expr,
| 410 | +                parfor_index_tuple_var,
| 411 | +                all_parfor_indices,
| 412 | +                el_typ,
| 413 | +                calltypes,
| 414 | +                typingctx,
| 415 | +                typemap,
| 416 | +                init_block,
| 417 | +                out_ir,
| 418 | +            )
| 419 | +        else:
| 420 | +            el_typ = var_typ
| 421 | +            ir_expr = expr
| 422 | +        out_ir.append(ir.Assign(ir_expr, expr_out_var, loc))
| 423 | +    elif isinstance(expr, ir.Const):
| 424 | +        el_typ = typing.Context().resolve_value_type(expr.value)
| 425 | +        out_ir.append(ir.Assign(expr, expr_out_var, loc))
| 426 | +
| 427 | +    if len(out_ir) == 0:
| 428 | +        raise errors.UnsupportedRewriteError(
| 429 | +            f"Don't know how to translate array expression '{expr!r}'",
| 430 | +            loc=expr.loc,
| 431 | +        )
| 432 | +    typemap.pop(expr_out_var.name, None)
| 433 | +    typemap[expr_out_var.name] = el_typ
| 434 | +    return out_ir
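The diff reads more easily with the shape of the expression tree that `_arrayexpr_tree_to_ir` walks in mind: interior nodes are `(op, args)` tuples, where `op` is either an operator from `npydecl.supported_array_operators` or a ufunc object (dpnp ufuncs carry the `is_dpnp_ufunc` attribute that the added `hasattr` check dispatches to `_ufunc_to_parfor_instr`), and leaves are `ir.Var` or `ir.Const` nodes. The sketch below only illustrates that recursion and is not numba-dpex API; `a`, `b`, `sqrt_ufunc`, `expr_tree`, and `lower` are hypothetical placeholders, and plain strings stand in for IR nodes so the snippet runs without a typing context.

```python
import operator

# Placeholders: real leaves are ir.Var / ir.Const nodes and the real op would
# be a dpnp ufunc object; strings keep the sketch self-contained.
a = "ir.Var('a')"
b = "ir.Var('b')"
sqrt_ufunc = "dpnp.sqrt"

# The array expression  a + sqrt(b)  arrives as this nested tuple.
expr_tree = (operator.add, [a, (sqrt_ufunc, [b])])


def lower(expr, depth=0):
    """Mimic the recursion in _arrayexpr_tree_to_ir: for an (op, args) tuple,
    lower every argument first, then 'emit' the op over the results."""
    pad = "    " * depth
    if isinstance(expr, tuple):
        op, args = expr
        results = [lower(arg, depth + 1) for arg in args]
        print(f"{pad}emit {op} over {results}")
        return f"tmp{depth}"
    # ir.Var / ir.Const leaf: nothing to recurse into.
    print(f"{pad}load {expr}")
    return expr


lower(expr_tree)
```

In the real pass each "emit" step appends `ir.Assign` nodes to `out_ir` and records the resolved operator or call type in `calltypes`, which is what the `op in npydecl.supported_array_operators`, `MAP_TYPES`, and `is_dpnp_ufunc` branches above do.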