
Commit 17b5543

Add intent-based ref handling and out-return support for bindings (#278)
In ISO C++, whether a function mutates an argument cannot be inferred from the signature alone. For example,

```c++
void set_as_42(int &a) { a = 42; }
const int& operator+ (const int&, const int&);
```

both accept references as input, but whether a function modifies its argument is unknown until its body is analyzed. This is confusing for language bindings, which must decide whether to pass arguments by reference or by value, because other languages may dictate argument-passing semantics differently from C++.

Some compilers provide additional annotation features to denote this, such as [SAL](https://learn.microsoft.com/en-us/cpp/code-quality/best-practices-and-examples-sal?view=msvc-170).

In this PR, Numbast introduces an `argument intent` option that lets users configure the argument-passing mode on a per-function, per-parameter basis. For each argument, the following options are available:

```
- in (default; the argument is passed as-is by value)
- inout_ptr (the argument is used as input and also mutated; exposed in Numba as a CPointer to its type)
- out_ptr (the argument is used to store output, thus mutated; exposed in Numba as a CPointer to its type)
- out_return (the argument is pre-allocated on the stack and returned from the Numba binding)
```

Take the following C++ signature as an example:

```c++
// A C++ function that processes `input`, writes result to `out`, and returns an exit code.
int mutative(int &out, int input);
```

A typical argument intent setup for Numbast looks like:

```
{"mutative": {"out": "out_return"}}
```

This indicates that the argument `out` is returned as part of the function's return value in the corresponding binding. Because this function already has a return value, the binding now returns a tuple of ints: the first element is the exit code and the second is the result. The Python binding signature:

```python
def mutative(input: numba.int32) -> Tuple[numba.int32, numba.int32]: ...
```

Alternatively, if the intent is set to `"out": "out_ptr"`, the signature becomes:

```python
def mutative(out: CPointer(int32), input: int32) -> int32: ...
```

## Summary by CodeRabbit

* **New Features**
  * Introduced a per-function argument-intent system to control parameter semantics (in, inout, out, out-return) and pointer-passing behavior.
  * Support for returning multiple values via out-parameters alongside regular returns.
  * Static binding generation and rendering now accept and propagate per-function intent overrides.
* **Tests**
  * Added end-to-end tests, fixtures and test data exercising out-parameters, in/out-pointer semantics, out-return behavior, and mutative device functions.

---------

Co-authored-by: Michael Wang <isVoid@users.noreply.github.com>
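The way an intent mapping splits a C++ parameter list into Numba-visible arguments and returned outputs can be sketched in plain Python. This is a minimal sketch under stated assumptions, not Numbast's implementation: `SketchIntentPlan` and `sketch_plan` are hypothetical names, the field names mirror the attributes that `callconv.py` reads from `IntentPlan` in this commit, and the rule that both pointer intents set the pass-pointer mask is an assumption, since `numbast.intent` itself is not shown here.

```python
from dataclasses import dataclass

VALID_INTENTS = ("in", "inout_ptr", "out_ptr", "out_return")


@dataclass(frozen=True)
class SketchIntentPlan:
    """Hypothetical stand-in for numbast.intent.IntentPlan."""

    intents: tuple                 # one intent per original C++ parameter
    visible_param_indices: tuple   # original indices kept in the Numba signature
    out_return_indices: tuple      # original indices moved to the return value
    pass_ptr_mask: tuple           # one flag per *visible* parameter


def sketch_plan(param_names, overrides):
    """Resolve per-parameter intents; unspecified parameters default to "in"."""
    unknown = set(overrides) - set(param_names)
    if unknown:
        raise ValueError(f"unknown parameter(s): {sorted(unknown)}")
    intents = tuple(overrides.get(name, "in") for name in param_names)
    for intent in intents:
        if intent not in VALID_INTENTS:
            raise ValueError(f"unknown intent: {intent!r}")
    visible = tuple(i for i, it in enumerate(intents) if it != "out_return")
    out_return = tuple(i for i, it in enumerate(intents) if it == "out_return")
    # Assumption: pointer intents are passed through as pointers.
    mask = tuple(intents[i] in ("inout_ptr", "out_ptr") for i in visible)
    return SketchIntentPlan(intents, visible, out_return, mask)


# int mutative(int &out, int input);
config = {"mutative": {"out": "out_return"}}
plan = sketch_plan(["out", "input"], config["mutative"])
ptr_plan = sketch_plan(["out", "input"], {"out": "out_ptr"})
```

With `out_return`, only `input` stays visible and `out` moves to the return side; with `out_ptr`, both parameters stay visible and only `out` is flagged for pointer passthrough.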
1 parent 036f2ad commit 17b5543

21 files changed, +1277 -90 lines changed

.pre-commit-config.yaml

Lines changed: 1 addition & 1 deletion
@@ -37,7 +37,7 @@ repos:
       - id: codespell
         additional_dependencies:
           - tomli
-        args: ["--toml", "pyproject.toml"]
+        args: ["--toml", "pyproject.toml", "--ignore-words-list", "inout"]
   - repo: https://github.com/google/yamlfmt
     rev: v0.16.0
     hooks:

numbast/src/numbast/args.py

Lines changed: 23 additions & 2 deletions
@@ -7,7 +7,10 @@
 
 
 def prepare_ir_types(
-    context: CUDATargetContext, argtys: list[ir.Type]
+    context: CUDATargetContext,
+    argtys: list[ir.Type],
+    *,
+    pass_ptr_mask: list[bool] | None = None,
 ) -> list[ir.Type]:
     """
     Prepare IR types for passing arguments via pointers in function calls.
@@ -28,4 +31,22 @@ def prepare_ir_types(
     list[ir.Type]
         List of pointer types wrapping the value types of each argument.
     """
-    return [ir.PointerType(context.get_value_type(argty)) for argty in argtys]
+    if pass_ptr_mask is None:
+        pass_ptr_mask = [False] * len(argtys)
+
+    if len(pass_ptr_mask) != len(argtys):
+        raise ValueError(
+            f"pass_ptr_mask length ({len(pass_ptr_mask)}) must match argtys length ({len(argtys)})"
+        )
+
+    ir_types: list[ir.Type] = []
+    for argty, passthrough in zip(argtys, pass_ptr_mask):
+        vty = context.get_value_type(argty)
+        if passthrough and isinstance(vty, ir.PointerType):
+            # Pass pointer-typed values directly (e.g. C++ T& mapped to CPointer(T))
+            ir_types.append(vty)
+        else:
+            # Default ABI: pass pointer-to-value
+            ir_types.append(ir.PointerType(vty))
+
+    return ir_types
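The masking rule in `prepare_ir_types` can be illustrated without Numba or llvmlite. The sketch below is an assumption-laden stand-in: `Scalar`, `Pointer`, and `sketch_prepare_types` are hypothetical names, with `Pointer` playing the role of `ir.PointerType` and the input list playing the role of the already-resolved value types.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Scalar:
    """Toy value type, e.g. Scalar("i32")."""

    name: str


@dataclass(frozen=True)
class Pointer:
    """Toy stand-in for llvmlite's ir.PointerType."""

    pointee: object


def sketch_prepare_types(value_types, pass_ptr_mask=None):
    """Return the shim parameter types for the given value types."""
    if pass_ptr_mask is None:
        pass_ptr_mask = [False] * len(value_types)
    if len(pass_ptr_mask) != len(value_types):
        raise ValueError("pass_ptr_mask length must match value_types length")

    shim_types = []
    for vty, passthrough in zip(value_types, pass_ptr_mask):
        if passthrough and isinstance(vty, Pointer):
            # Already a pointer and flagged: hand it to the shim unchanged.
            shim_types.append(vty)
        else:
            # Default ABI: wrap the value type in one more pointer level.
            shim_types.append(Pointer(vty))
    return shim_types


i32 = Scalar("i32")
with_mask = sketch_prepare_types([Pointer(i32), i32], [True, False])
without_mask = sketch_prepare_types([Pointer(i32), i32])
```

Note that the flag alone is not enough: a flagged argument whose value type is not a pointer still falls back to the default pointer-to-value path.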

numbast/src/numbast/callconv.py

Lines changed: 158 additions & 16 deletions
@@ -1,4 +1,10 @@
+# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+
 from numbast.args import prepare_ir_types
+from numbast.intent import IntentPlan
+
+# NBST:BEGIN_CALLCONV
 from numba.cuda import types, cgutils
 
 from llvmlite import ir
@@ -36,27 +42,133 @@ def __call__(self, builder, context, sig, args):
 
 
 class FunctionCallConv(BaseCallConv):
+    def __init__(
+        self,
+        itanium_mangled_name: str,
+        shim_writer: object,
+        shim_code: str,
+        *,
+        arg_is_ref: list[bool] | None = None,
+        intent_plan: IntentPlan | None = None,
+        out_return_types: list[types.Type] | None = None,
+        cxx_return_type: types.Type | None = None,
+    ):
+        super().__init__(itanium_mangled_name, shim_writer, shim_code)
+        self._arg_is_ref = list(arg_is_ref) if arg_is_ref is not None else None
+        self._intent_plan = intent_plan
+        self._out_return_types = (
+            list(out_return_types) if out_return_types is not None else None
+        )
+        self._cxx_return_type = cxx_return_type
+
     def _lower_impl(self, builder, context, sig, args):
-        return_type = sig.return_type
+        # Numba-visible return type may differ from the underlying C++ return type
+        # when out_return parameters are enabled (tuple returns, etc.).
+        cxx_return_type = (
+            self._cxx_return_type
+            if self._cxx_return_type is not None
+            else sig.return_type
+        )
         # 1. Prepare return value pointer
-        if return_type == types.void:
+        if cxx_return_type == types.void:
             # Void return type in C++ is shimmed as int& ignored
             retval_ty = ir.IntType(32)
             retval_ptr = builder.alloca(retval_ty, name="ignored")
         else:
-            retval_ty = context.get_value_type(return_type)
+            retval_ty = context.get_value_type(cxx_return_type)
             retval_ptr = builder.alloca(retval_ty, name="retval")
 
         # 2. Prepare arguments
-        arg_pointer_types = prepare_ir_types(context, sig.args)
-
-        # All arguments are passed by pointer
-        ptrs = [
-            cgutils.alloca_once(builder, context.get_value_type(argty))
-            for argty in sig.args
-        ]
-        for ptr, argty, arg in zip(ptrs, sig.args, args):
-            builder.store(arg, ptr, align=getattr(argty, "alignof_", None))
+        if self._intent_plan is None:
+            pass_ptr_mask = (
+                self._arg_is_ref
+                if self._arg_is_ref is not None
+                else [False] * len(sig.args)
+            )
+            arg_pointer_types = prepare_ir_types(
+                context, sig.args, pass_ptr_mask=pass_ptr_mask
+            )
+        else:
+            plan = self._intent_plan
+            if len(sig.args) != len(plan.visible_param_indices):
+                raise ValueError(
+                    "Signature args do not match intent plan visible params: "
+                    f"sig has {len(sig.args)} args but plan expects {len(plan.visible_param_indices)}"
+                )
+            if len(plan.pass_ptr_mask) != len(sig.args):
+                raise ValueError(
+                    "Intent plan pass_ptr_mask length does not match signature args length: "
+                    f"{len(plan.pass_ptr_mask)} != {len(sig.args)}"
+                )
+            if plan.out_return_indices:
+                if self._out_return_types is None:
+                    raise ValueError(
+                        "out_return intent plan requires out_return_types to be provided"
+                    )
+                if len(self._out_return_types) != len(plan.out_return_indices):
+                    raise ValueError(
+                        "out_return_types length does not match intent plan out_return_indices: "
+                        f"{len(self._out_return_types)} != {len(plan.out_return_indices)}"
+                    )
+            arg_pointer_types = []  # computed below alongside ptrs
+
+        # ABI:
+        # - default: pass pointer-to-value to shim (alloca + store)
+        # - for C++ reference args mapped to CPointer(T): pass pointer value directly
+        ptrs = []
+        out_return_ptrs: list[tuple[types.Type, ir.Value]] = []
+        if self._intent_plan is None:
+            for argty, arg, passthrough in zip(sig.args, args, pass_ptr_mask):
+                vty = context.get_value_type(argty)
+                if passthrough and isinstance(vty, ir.PointerType):
+                    ptrs.append(arg)
+                else:
+                    ptr = cgutils.alloca_once(builder, vty)
+                    builder.store(
+                        arg, ptr, align=getattr(argty, "alignof_", None)
+                    )
+                    ptrs.append(ptr)
+        else:
+            plan = self._intent_plan
+            n_orig = len(plan.intents)
+            # Map original parameter index -> visible signature position / out_return position
+            orig_to_vis = [None] * n_orig
+            for vis_pos, orig_idx in enumerate(plan.visible_param_indices):
+                orig_to_vis[orig_idx] = vis_pos
+            orig_to_out = [None] * n_orig
+            for out_pos, orig_idx in enumerate(plan.out_return_indices):
+                orig_to_out[orig_idx] = out_pos
+
+            for orig_idx in range(n_orig):
+                out_pos = orig_to_out[orig_idx]
+                if out_pos is not None:
+                    out_nbty = self._out_return_types[out_pos]
+                    vty = context.get_value_type(out_nbty)
+                    ptr = cgutils.alloca_once(builder, vty)
+                    ptrs.append(ptr)
+                    arg_pointer_types.append(ir.PointerType(vty))
+                    out_return_ptrs.append((out_nbty, ptr))
+                    continue
+
+                vis_pos = orig_to_vis[orig_idx]
+                if vis_pos is None:
+                    raise ValueError(
+                        f"Internal error: original param {orig_idx} is neither visible nor out_return"
+                    )
+                argty = sig.args[vis_pos]
+                arg = args[vis_pos]
+                passthrough = bool(plan.pass_ptr_mask[vis_pos])
+                vty = context.get_value_type(argty)
+                if passthrough and isinstance(vty, ir.PointerType):
+                    ptrs.append(arg)
+                    arg_pointer_types.append(vty)
+                else:
+                    ptr = cgutils.alloca_once(builder, vty)
+                    builder.store(
+                        arg, ptr, align=getattr(argty, "alignof_", None)
+                    )
+                    ptrs.append(ptr)
+                    arg_pointer_types.append(ir.PointerType(vty))
 
         # 3. Declare shim
         # Shim signature: int (retval_type*, arg0_type*, ...)
@@ -71,9 +183,39 @@ def _lower_impl(self, builder, context, sig, args):
         builder.call(fn, (retval_ptr, *ptrs))
 
         # 5. Return
-        if return_type == types.void:
-            return None
-        else:
+        if (
+            self._intent_plan is None
+            or not self._intent_plan.out_return_indices
+        ):
+            if cxx_return_type == types.void:
+                return None
             return builder.load(
-                retval_ptr, align=getattr(return_type, "alignof_", None)
+                retval_ptr, align=getattr(cxx_return_type, "alignof_", None)
+            )
+
+        # out_return enabled: return either a value or a tuple (ret, out1, out2, ...)
+        ret_vals: list[ir.Value] = []
+        if cxx_return_type != types.void:
+            ret_vals.append(
+                builder.load(
+                    retval_ptr, align=getattr(cxx_return_type, "alignof_", None)
+                )
+            )
+        for out_ty, out_ptr in out_return_ptrs:
+            ret_vals.append(
+                builder.load(out_ptr, align=getattr(out_ty, "alignof_", None))
+            )
+
+        # If Numba-visible return is a tuple, use context.make_tuple.
+        # Otherwise (void + single out), return the single out value directly.
+        if hasattr(sig.return_type, "types"):
+            return context.make_tuple(builder, sig.return_type, ret_vals)
+        if len(ret_vals) != 1:
+            raise ValueError(
+                "Non-tuple return type requires exactly one return value; "
+                f"got {len(ret_vals)}"
             )
+        return ret_vals[0]
+
+
+# NBST:END_CALLCONV
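The return-assembly rules in step 5 can be modeled on plain Python values. This is a simplified sketch, not the lowering code: `sketch_assemble_return` is a hypothetical name, and it chooses between a single value and a tuple by counting values, whereas the lowering above checks whether `sig.return_type` is a tuple type.

```python
def sketch_assemble_return(cxx_value, out_values, *, cxx_is_void):
    """Combine the C++ return value with out_return values.

    cxx_value is ignored when cxx_is_void is True.
    """
    if not out_values:
        # No out_return parameters: behave like the original binding.
        return None if cxx_is_void else cxx_value

    # The C++ return value (if any) comes first, then outputs in order.
    ret_vals = [] if cxx_is_void else [cxx_value]
    ret_vals.extend(out_values)

    if len(ret_vals) == 1:
        # void function with a single out_return: return the value directly.
        return ret_vals[0]
    return tuple(ret_vals)


# int mutative(int &out, int input) with {"out": "out_return"}:
exit_code_and_result = sketch_assemble_return(0, [42], cxx_is_void=False)
# A hypothetical void function with one out_return parameter:
single_out = sketch_assemble_return(None, [42], cxx_is_void=True)
```

The ordering matches the commit message: the exit code comes first and the out-parameter result second.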
