Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion numba_cuda/numba/cuda/core/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ def compute_use_defs(blocks):
"""
Find variable use/def per block.
"""
from numba.cuda.core import ir_utils

var_use_map = {} # { block offset -> set of vars }
var_def_map = {} # { block offset -> set of vars }
Expand Down Expand Up @@ -56,7 +57,7 @@ def compute_use_defs(blocks):
if stmt.target.name not in rhs_set:
def_set.add(stmt.target.name)

for var in stmt.list_vars():
for var in ir_utils.compat_list_vars_stmt(stmt):
# do not include locally defined vars to use-map
if var.name not in def_set:
use_set.add(var.name)
Expand Down
34 changes: 31 additions & 3 deletions numba_cuda/numba/cuda/core/ir_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -510,6 +510,32 @@ def visit_vars_inner(node, callback, cbdata):
return node


def _collect_vars_callback(var, collected):
collected.append(var)
return var


def compat_list_vars_stmt(stmt):
    """List variables in a statement, robust to mixed IR node classes.

    Walks *stmt* with ``visit_vars_stmt`` and returns every variable the
    visitor encounters.  When the walk yields nothing, defer to the
    statement's own ``list_vars()`` so node classes the visitor does not
    recognise are still covered.
    """
    found = []
    visit_vars_stmt(stmt, _collect_vars_callback, found)
    return found if found else stmt.list_vars()


def compat_list_vars_node(node):
    """List variables in an IR node/expression, robust to mixed IR classes.

    Prefers the generic ``visit_vars_inner`` walk; when that visits no
    variables, fall back to the node's own ``list_vars()``.  Nodes that
    lack ``list_vars`` contribute no variables.
    """
    found = []
    visit_vars_inner(node, _collect_vars_callback, found)
    if found:
        return found
    try:
        return node.list_vars()
    except AttributeError:
        # Node class has no list_vars(); treat it as variable-free.
        return ()


add_offset_to_labels_extensions = {}


Expand Down Expand Up @@ -654,7 +680,7 @@ def remove_dead(
removed = False
for label, block in blocks.items():
# find live variables at each statement to delete dead assignment
lives = {v.name for v in block.terminator.list_vars()}
lives = {v.name for v in compat_list_vars_stmt(block.terminator)}
if config.DEBUG_ARRAY_OPT >= 2:
print("remove_dead processing block", label, lives)
# find live variables at the end of block
Expand Down Expand Up @@ -768,11 +794,13 @@ def remove_dead_block(
lives -= defs
lives |= uses
else:
lives |= {v.name for v in stmt.list_vars()}
lives |= {v.name for v in compat_list_vars_stmt(stmt)}
if isinstance(stmt, ir.assign_types):
# make sure lhs is not used in rhs, e.g. a = g(a)
if isinstance(stmt.value, ir.expr_types):
rhs_vars = {v.name for v in stmt.value.list_vars()}
rhs_vars = {
v.name for v in compat_list_vars_node(stmt.value)
}
if lhs.name not in rhs_vars:
lives.remove(lhs.name)
else:
Expand Down
12 changes: 5 additions & 7 deletions numba_cuda/numba/cuda/np/npyimpl.py
Original file line number Diff line number Diff line change
Expand Up @@ -1017,15 +1017,13 @@ def codegen(context, builder, signature, args):


@overload(np.dtype)
def numpy_dtype(dtype, align=False, copy=False):
def numpy_dtype(desc):
"""Provide an implementation so that numpy.dtype function can be lowered."""
if isinstance(dtype, (types.Literal, types.functions.NumberClass)):
if isinstance(desc, (types.Literal, types.functions.NumberClass)):

def imp(dtype, align=False, copy=False):
return _make_dtype_object(dtype)
def imp(desc):
return _make_dtype_object(desc)

return imp
else:
raise errors.NumbaTypeError(
"unknown dtype descriptor: {}".format(dtype)
)
raise errors.NumbaTypeError("unknown dtype descriptor: {}".format(desc))
26 changes: 26 additions & 0 deletions numba_cuda/numba/cuda/tests/cudapy/test_numba_interop.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,23 @@
if HAS_NUMBA:
from numba.extending import overload

# User-facing repro shape from Issue #718: global overload + global kernel.
def issue_718_get_42():
    """Host-side stub for the CUDA overload registered below.

    The device implementation is supplied via ``@overload`` with
    ``target="cuda"``; calling this function directly always raises.
    """
    raise NotImplementedError()

@overload(issue_718_get_42, target="cuda", inline="always")
def issue_718_overload_get_42():
    # CUDA-target overload of issue_718_get_42.  inline="always" is the
    # key ingredient of the Issue #718 repro: the inlined implementation
    # must survive dead-code elimination over mixed IR node classes.
    def impl():
        # Device-local scratch array written then read back, so the
        # overload's result observably depends on the local store.
        a = cuda.local.array(1, dtype=np.float32)
        a[0] = 42.0
        return a[0]

    return impl

@cuda.jit
def issue_718_kernel(a):
    # Module-level kernel calling the overloaded function, mirroring the
    # end-user shape reported in Issue #718 (global overload + global
    # kernel).  Writes the overload's result into the output array.
    a[0] = issue_718_get_42()


@skip_on_cudasim("Simulator does not support the extension API")
@unittest.skipUnless(HAS_NUMBA, "Tests interoperability with Numba")
Expand All @@ -33,3 +50,12 @@ def kernel(a):
a = np.empty(1, dtype=np.float32)
kernel[1, 1](a)
np.testing.assert_equal(a[0], 42)

def test_overload_inline_always_local_array(self):
    # From Issue #718
    # Keep the test body as close as possible to end-user kernel launch.
    a = np.empty(1, dtype=np.float32)
    # Explicit device round-trip (to_device / copy_to_host) rather than
    # implicit transfer, matching the original user report.
    d_a = cuda.to_device(a)
    issue_718_kernel[1, 1](d_a)
    d_a.copy_to_host(a)
    # The overload stores 42.0 via a local array; a pass of the test
    # shows the inlined IR was not wrongly removed as dead code.
    np.testing.assert_equal(a[0], 42.0)
30 changes: 15 additions & 15 deletions pixi.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading