Description
@shwina reported that the following testcase (reduced from CCCL):
```python
from numba import cuda
from numba import types
from numba.core.extending import intrinsic


def foo(x):
    return 2 * x


# This works - compiling a plain function
print("Compiling plain function...")
cuda.compile(foo, (types.int32,), output="ltoir")
print("✓ Plain function compilation succeeded")


# Now create a wrapper that takes void* arguments
# This is what breaks in newer numba versions
def create_void_ptr_wrapper():
    """Create a wrapper that takes void* input and output pointers."""
    # Make foo a device function
    foo_device = cuda.jit(device=True)(foo)

    # The inner signature: int32 -> int32
    inner_sig = types.int32(types.int32)

    # The wrapper signature: void(void*, void*) - input ptr, output ptr
    wrapper_sig = types.void(types.voidptr, types.voidptr)

    @intrinsic
    def wrapper_impl(typingctx, arg0, arg1):
        def codegen(context, builder, sig, args):
            input_ptr, output_ptr = args

            # Cast input void* to int32*, load value
            int32_llvm_type = context.get_value_type(types.int32)
            typed_input_ptr = builder.bitcast(input_ptr, int32_llvm_type.as_pointer())
            input_val = builder.load(typed_input_ptr)

            # Call the inner function
            cres = context.compile_subroutine(builder, foo_device, inner_sig, caching=False)
            result = context.call_internal(builder, cres.fndesc, inner_sig, [input_val])

            # Cast output void* to int32*, store result
            typed_output_ptr = builder.bitcast(output_ptr, int32_llvm_type.as_pointer())
            builder.store(result, typed_output_ptr)

            return context.get_dummy_value()

        return wrapper_sig, codegen

    def wrapper_func(input_ptr, output_ptr):
        return wrapper_impl(input_ptr, output_ptr)

    return wrapper_func, wrapper_sig


print("\nCreating void* wrapper...")
wrapper, wrapper_sig = create_void_ptr_wrapper()

# This is what breaks - trying to compile the wrapper
print("\nCompiling void* wrapper to ltoir...")
cuda.compile(wrapper, wrapper_sig.args, output="ltoir")
```

Running this results in an exception with the following end of the traceback:
```
  File "/home/gmarkall/numbadev/numba-cuda/numba_cuda/numba/cuda/core/base.py", line 1305, in wrapper
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/home/gmarkall/numbadev/issues/cccl-call-internal/minimal.py", line 43, in codegen
    result = context.call_internal(builder, cres.fndesc, inner_sig, [input_val])
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/gmarkall/numbadev/numba-cuda/numba_cuda/numba/cuda/core/base.py", line 964, in call_internal
    with cgutils.if_unlikely(builder, status.is_error):
                                      ^^^^^^^^^^^^^^^
AttributeError: 'NoneType' object has no attribute 'is_error'
```
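The failure mode can be illustrated without numba at all. The following is a plain-Python sketch (all names here are illustrative stand-ins, not numba's own API): with the Numba calling convention a call hands back a status object, while with the C calling convention there is no status, so the unguarded attribute access fails exactly as in the traceback.

```python
# Plain-Python illustration of why the traceback ends in an AttributeError.
class Status:
    is_error = False  # what the Numba calling convention hands back

def check_status(status):
    # mirrors the unguarded access in call_internal()
    return status.is_error

print(check_status(Status()))  # False: Numba calling convention
try:
    check_status(None)         # C calling convention: no status object
except AttributeError as exc:
    print(exc)                 # 'NoneType' object has no attribute 'is_error'
```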
Initially it looks like `call_internal()` is missing a check for `status` being `None`. The `call_internal()` function is:
numba-cuda/numba_cuda/numba/cuda/core/base.py, lines 956 to 968 at 9de460e:

```python
def call_internal(self, builder, fndesc, sig, args):
    """
    Given the function descriptor of an internally compiled function,
    emit a call to that function with the given arguments.
    """
    status, res = self.call_internal_no_propagate(
        builder, fndesc, sig, args
    )
    with cgutils.if_unlikely(builder, status.is_error):
        fndesc.call_conv.return_status_propagate(builder, status)
    res = imputils.fix_returning_optional(self, builder, sig, status, res)
    return res
```
This has some similarity to `imputils.user_function.imp()`:
numba-cuda/numba_cuda/numba/cuda/core/imputils.py, lines 215 to 224 at 9de460e:

```python
def imp(context, builder, sig, args, fndesc=fndesc):
    func = fndesc.declare_function(builder.module)
    # env=None assumes this is a nopython function
    status, retval = fndesc.call_conv.call_function(
        builder, func, fndesc.restype, fndesc.argtypes, args
    )
    if status is not None:
        with cgutils.if_unlikely(builder, status.is_error):
            fndesc.call_conv.return_status_propagate(builder, status)
```
The `if status is not None` guard was added in #717 to guard against `status` being `None` when the called function has the C calling convention. So it does look like we are missing the same guard in `call_internal()`; the following diff adds one:
```diff
diff --git a/numba_cuda/numba/cuda/core/base.py b/numba_cuda/numba/cuda/core/base.py
index c518f9cb..2bdfe846 100644
--- a/numba_cuda/numba/cuda/core/base.py
+++ b/numba_cuda/numba/cuda/core/base.py
@@ -961,8 +961,9 @@ class BaseContext:
         status, res = self.call_internal_no_propagate(
             builder, fndesc, sig, args
         )
-        with cgutils.if_unlikely(builder, status.is_error):
-            fndesc.call_conv.return_status_propagate(builder, status)
+        if status is not None:
+            with cgutils.if_unlikely(builder, status.is_error):
+                fndesc.call_conv.return_status_propagate(builder, status)
         res = imputils.fix_returning_optional(self, builder, sig, status, res)
         return res
```

Applying this does indeed allow the reproducer to run to completion. However, it does not address the root cause. Note from the reproducer that the function being called is `foo_device()`, which is a normal jitted function. The relevant bits from the reproducer above:
```python
def foo(x):
    return 2 * x

# Make foo a device function
foo_device = cuda.jit(device=True)(foo)

# Call the inner function
cres = context.compile_subroutine(builder, foo_device, inner_sig, caching=False)
result = context.call_internal(builder, cres.fndesc, inner_sig, [input_val])
```

`foo_device` should have the Numba calling convention, because it is a jitted function. For the reproducer above this makes no difference, because `foo` does not raise an exception, but any code that iterates over anything (e.g. uses `range()`) will fail, because `status` is used to hold the `StopIteration` exception.
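For example, a hypothetical variant of `foo` (not from the original report) that iterates would exercise the status object even though it raises nothing visible at the Python level, because Numba implements iteration in terms of `StopIteration` reported through the status. The function itself is ordinary Python:

```python
# Hypothetical variant of foo that iterates. When compiled with the Numba
# calling convention, the end of the range() loop is signalled via a
# StopIteration carried in the status object; a callee wrongly compiled
# with the C calling convention has no status to carry it.
def foo_loop(x):
    total = 0
    for i in range(x):
        total += i
    return total

print(foo_loop(5))  # 10
```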
Other parts of the problem are:

- We now store the calling convention in the flags, so that it can be picked up in the pipeline to create the correct mangler implementation:

  numba-cuda/numba_cuda/numba/cuda/core/typed_passes.py, lines 329 to 333 at 9de460e:

  ```python
  call_conv = flags.call_conv
  if call_conv is None:
      call_conv = CUDACallConv(state.targetctx)
  mangler = call_conv.mangler
  ```

- The `compile_subroutine()` call is creating the `cres` that has the function compiled with the wrong calling convention. It ends up in `_compile_subroutine_no_cache()`, which applies the flags of the top-level function if no flags are attached to the current function:

  numba-cuda/numba_cuda/numba/cuda/target.py, lines 398 to 408 at 9de460e:

  ```python
  with global_compiler_lock:
      codegen = self.codegen()
      library = codegen.create_library(impl.__name__)
      if flags is None:
          cstk = targetconfig.ConfigStack()
          if cstk:
              flags = cstk.top().copy()
          else:
              msg = "There should always be a context stack; none found."
              warnings.warn(msg, NumbaWarning)
              flags = CUDAFlags()
  ```

- There is also a comment about this being part of a bugfix for another issue (I need to look into what that is / was):

  numba-cuda/numba_cuda/numba/cuda/target.py, lines 391 to 393 at 9de460e:

  ```python
  # Overrides numba.core.base.BaseContext._compile_subroutine_no_cache().
  # Modified to use flags from the context stack if they are not provided
  # (pending a fix in Numba upstream).
  ```

- When calling `cuda.compile()` with `abi="c"`, the top-level function has the C calling convention.
I think the right fix is probably to move the mangler into the calling convention, and therefore to not need to put the calling convention in the flags (so that it can't be picked up accidentally by a callee), but this needs trying out to see whether there are further issues.
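As a rough sketch of that direction (class names and mangling schemes below are purely illustrative, not numba's): if each calling convention owns its mangler, the pipeline can ask the convention object directly, and the convention never needs to travel in the flags where a callee can inherit it by accident.

```python
# Sketch: mangler as a method of the calling convention itself.
class CallConv:
    def mangle(self, name, argtypes):
        raise NotImplementedError

class NumbaCallConv(CallConv):
    def mangle(self, name, argtypes):
        # stand-in for Numba's real ABI mangling scheme
        return "_ZN_" + name + "_" + "_".join(argtypes)

class CCallConv(CallConv):
    def mangle(self, name, argtypes):
        return name  # C symbols keep their plain name

print(CCallConv().mangle("foo", ["int32"]))      # foo
print(NumbaCallConv().mangle("foo", ["int32"]))  # _ZN_foo_int32
```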