Skip to content

Commit a0d78df

Browse files
committed
Generate call to dpnp instead of np for divide
1 parent 2dce65a commit a0d78df

File tree

3 files changed

+39
-8
lines changed

3 files changed

+39
-8
lines changed

numba_dpex/core/parfors/parfor_pass.py

Lines changed: 38 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
import operator
1414
import warnings
1515

16+
import dpnp
1617
from numba.core import config, errors, ir, types, typing
1718
from numba.core.compiler_machinery import register_pass
1819
from numba.core.ir_utils import (
@@ -42,6 +43,8 @@
4243
)
4344
from numba.stencils.stencilparfor import StencilPass
4445

46+
from numba_dpex.core.typing import dpnpdecl
47+
4548

4649
class ConvertDPNPPass(ConvertNumpyPass):
4750
"""
@@ -293,6 +296,39 @@ def _ufunc_to_parfor_instr(
293296
return el_typ
294297

295298

299+
def get_dpnp_ufunc_typ(func):
300+
"""get type of the incoming function from builtin registry"""
301+
for k, v in dpnpdecl.registry.globals:
302+
if k == func:
303+
return v
304+
raise RuntimeError("type for func ", func, " not found")
305+
306+
307+
def _gen_dpnp_divide(arg1, arg2, out_ir, typemap):
308+
"""generate np.divide() instead of / for array_expr to get numpy error model
309+
like inf for division by zero (test_division_by_zero).
310+
"""
311+
scope = arg1.scope
312+
loc = arg1.loc
313+
g_np_var = ir.Var(scope, mk_unique_var("$np_g_var"), loc)
314+
typemap[g_np_var.name] = types.misc.Module(dpnp)
315+
g_np = ir.Global("dpnp", dpnp, loc)
316+
g_np_assign = ir.Assign(g_np, g_np_var, loc)
317+
# attr call: div_attr = getattr(g_np_var, divide)
318+
div_attr_call = ir.Expr.getattr(g_np_var, "divide", loc)
319+
attr_var = ir.Var(scope, mk_unique_var("$div_attr"), loc)
320+
func_var_typ = get_dpnp_ufunc_typ(dpnp.divide)
321+
typemap[attr_var.name] = func_var_typ
322+
attr_assign = ir.Assign(div_attr_call, attr_var, loc)
323+
# divide call: div_attr(arg1, arg2)
324+
div_call = ir.Expr.call(attr_var, [arg1, arg2], (), loc)
325+
func_typ = func_var_typ.get_call_type(
326+
typing.Context(), [typemap[arg1.name], typemap[arg2.name]], {}
327+
)
328+
out_ir.extend([g_np_assign, attr_assign])
329+
return func_typ, div_call
330+
331+
296332
def _arrayexpr_tree_to_ir(
297333
func_ir,
298334
typingctx,
@@ -343,7 +379,8 @@ def _arrayexpr_tree_to_ir(
343379
)
344380
ir_expr = ir.Expr.binop(op, arg_vars[0], arg_vars[1], loc)
345381
if op == operator.truediv:
346-
func_typ, ir_expr = parfor._gen_np_divide(
382+
# NUMBA_DPEX: is_dpnp_func check was added
383+
func_typ, ir_expr = _gen_dpnp_divide(
347384
arg_vars[0], arg_vars[1], out_ir, typemap
348385
)
349386
else:

numba_dpex/dpnp_iface/dpnp_ufunc_db.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -81,10 +81,6 @@ def _fill_ufunc_db_with_dpnp_ufuncs(ufunc_db):
8181
op.types = npop.types
8282
op.is_dpnp_ufunc = True
8383
cp = copy.copy(_ufunc_db[npop])
84-
if "'divide'" in str(npop):
85-
# TODO: why do we need to do it only for divide?
86-
# https://github.com/IntelPython/numba-dpex/issues/1270
87-
ufunc_db.update({npop: cp})
8884
ufunc_db.update({op: cp})
8985
for key in list(ufunc_db[op].keys()):
9086
if (

numba_dpex/kernel_api_impl/spirv/target.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
from numba.core.callconv import MinimalCallConv
1919
from numba.core.target_extension import GPU, target_registry
2020
from numba.core.types.scalars import IntEnumClass
21-
from numba.core.typing import cmathdecl, enumdecl, npydecl
21+
from numba.core.typing import cmathdecl, enumdecl
2222

2323
from numba_dpex.core.datamodel.models import _init_data_model_manager
2424
from numba_dpex.core.types import IntEnumLiteral
@@ -108,8 +108,6 @@ def load_additional_registries(self):
108108
self.install_registry(ocldecl.registry)
109109
self.install_registry(mathdecl.registry)
110110
self.install_registry(cmathdecl.registry)
111-
# TODO: https://github.com/IntelPython/numba-dpex/issues/1270
112-
self.install_registry(npydecl.registry)
113111
self.install_registry(dpnpdecl.registry)
114112
self.install_registry(enumdecl.registry)
115113

0 commit comments

Comments
 (0)