Skip to content

Commit f91d941

Browse files
authored
Centralize Calling Convention and Simplify Binding Generation (#271)
This PR centralizes the C++ FFI calling convention in callconv.py and reuses it everywhere to reduce redundancy. Additionally, the CUDA 13 headers add a few compiler-native keywords / native types that prevent clang from parsing the code without syntax errors. This PR adds patches to prevent those errors. Note that these patches should not be considered official support for the patched keywords / datatypes. <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit * **Refactor** * Unified call-convention framework for C++ function bindings, constructors, conversion operators, and methods, simplifying shim generation and lowering. * **New Features** * Added call-convention utilities into generated bindings and tooling to centralize shim/call logic. * Introduced a utility to prepare IR argument types for CUDA targets. * **Bug Fixes / Chores** * Ensure shim include path is always resolved (now fails fast if missing). * Include header shim files in distribution wheels. <sub>✏️ Tip: You can customize this high-level summary in your review settings.</sub> <!-- end of auto-generated comment: release notes by coderabbit.ai --> --------- Co-authored-by: Michael Wang <isVoid@users.noreply.github.com>
1 parent 553200d commit f91d941

File tree

15 files changed

+287
-593
lines changed

15 files changed

+287
-593
lines changed

ast_canopy/CMakeLists.txt

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,3 +32,12 @@ target_link_libraries(pylibastcanopy
3232
PUBLIC astcanopy::astcanopy)
3333

3434
install(TARGETS pylibastcanopy LIBRARY DESTINATION ast_canopy)
35+
36+
# Ensure header shims are shipped inside the wheel under:
37+
# site-packages/ast_canopy/shim_include/
38+
# scikit-build-core builds wheels from the CMake install tree, so non-built
39+
# sources must be installed explicitly.
40+
install(
41+
DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/ast_canopy/shim_include
42+
DESTINATION ast_canopy
43+
)

ast_canopy/ast_canopy/api.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -28,11 +28,18 @@
2828
logger = logging.getLogger(f"AST_Canopy.{__name__}")
2929

3030

31-
def _get_shim_include_dir() -> str | None:
32-
"""Return the absolute path to the local shim include directory, if present."""
31+
def _get_shim_include_dir() -> str:
32+
"""Return the absolute path to the local shim include directory"""
3333
here = os.path.dirname(__file__)
3434
shim_dir = os.path.join(here, "shim_include")
35-
return shim_dir if os.path.isdir(shim_dir) else None
35+
36+
if not os.path.isdir(shim_dir):
37+
raise RuntimeError(
38+
f"Shim include directory not found at {shim_dir}. "
39+
"This indicates a packaging issue. Please reinstall ast_canopy."
40+
)
41+
42+
return shim_dir
3643

3744

3845
@dataclass
@@ -435,7 +442,7 @@ def parse_declarations_from_source(
435442
f"-std={cxx_standard}",
436443
f"-resource-dir={clang_resource_dir}",
437444
# Place shim include dir early so it can intercept vendor headers.
438-
*([f"-I{_get_shim_include_dir()}"] if _get_shim_include_dir() else []),
445+
f"-I{_get_shim_include_dir()}",
439446
# cuda_wrappers_dir precede libstdc++ search includes to shadow certain
440447
# libstdc++ headers
441448
f"-isystem{cuda_wrappers_dir}",

ast_canopy/pyproject.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,9 @@ readme = { file = "README.md", content-type = "text/markdown" }
1313
[tool.scikit-build]
1414
cmake.targets = ["pylibastcanopy"]
1515
wheel.license-files = ["../LICENSE"]
16+
# Be explicit about shipping header shims in sdists so wheels built from sdist
17+
# (e.g., on PyPI) always contain these files.
18+
sdist.include = ["ast_canopy/shim_include/**"]
1619

1720
[tool.scikit-build.metadata.version]
1821
provider = "scikit_build_core.metadata.setuptools_scm"

numbast/src/numbast/args.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
4+
5+
from collections.abc import Sequence

from llvmlite import ir
from numba.cuda import types
from numba.cuda.target import CUDATargetContext
7+
8+
9+
def prepare_ir_types(
    context: CUDATargetContext, argtys: Sequence[types.Type]
) -> list[ir.Type]:
    """
    Build the pointer-typed LLVM IR parameter list for a shim call.

    Each Numba type in ``argtys`` is converted to its LLVM value type with
    ``context.get_value_type`` and then wrapped in ``ir.PointerType``. The
    result matches the call convention used by FunctionCallConv, where every
    argument is passed by reference.

    Parameters
    ----------
    context : CUDATargetContext
        The compilation context; only its ``get_value_type`` method is used.
    argtys : Sequence[types.Type]
        Numba types of the function arguments (for example ``sig.args``).
        These are Numba types, not LLVM IR types: the conversion to LLVM IR
        happens inside this function.

    Returns
    -------
    list[ir.Type]
        One ``ir.PointerType`` per entry of ``argtys``, in the same order,
        each pointing to the LLVM value type of the corresponding argument.
        An empty ``argtys`` yields an empty list.
    """
    return [ir.PointerType(context.get_value_type(argty)) for argty in argtys]

numbast/src/numbast/callconv.py

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
from numbast.args import prepare_ir_types
2+
from numba.cuda import types, cgutils
3+
4+
from llvmlite import ir
5+
6+
7+
class BaseCallConv:
    """Base class for lowering a call through a generated C++ shim function.

    An instance is callable as ``cc(builder, context, sig, args)``. Each call
    first hands the shim source to ``shim_writer`` and then delegates the
    actual code generation to ``_lower_impl``, which subclasses must provide.

    Note the argument order of the call: ``builder`` comes before ``context``.
    """

    # Pattern used to derive the shim symbol name from the mangled name.
    shim_function_template = "{mangled_name}_nbst"

    def __init__(
        self,
        itanium_mangled_name: str,
        shim_writer: object,
        shim_code: str,
    ):
        """Record the shim source and derive the shim symbol name.

        Parameters
        ----------
        itanium_mangled_name : str
            Mangled name that is substituted into ``shim_function_template``.
        shim_writer : object
            Object exposing ``write_to_shim(shim_code, shim_function_name)``.
        shim_code : str
            Source of the shim; it is only passed to ``shim_writer`` when the
            call is lowered, not at construction time.
        """
        self.itanium_mangled_name = itanium_mangled_name
        self.shim_writer = shim_writer
        self.shim_code = shim_code

        template = self.shim_function_template
        self.shim_function_name = template.format(
            mangled_name=self.itanium_mangled_name
        )

    def _lazy_write_shim(self, shim_code: str):
        """Pass ``shim_code`` to the shim writer under this shim's name."""
        writer = self.shim_writer
        writer.write_to_shim(shim_code, self.shim_function_name)

    def _lower(self, builder, context, sig, args):
        """Emit the shim source, then run the subclass lowering."""
        # The shim is written here, at lowering time, rather than in __init__.
        self._lazy_write_shim(self.shim_code)
        lowered = self._lower_impl(builder, context, sig, args)
        return lowered

    def _lower_impl(self, builder, context, sig, args):
        """Generate the call itself; must be overridden by subclasses."""
        raise NotImplementedError

    def __call__(self, builder, context, sig, args):
        """Lower the call; equivalent to ``self._lower(...)``."""
        return self._lower(builder, context, sig, args)
36+
37+
38+
class FunctionCallConv(BaseCallConv):
    """Call convention for shims shaped as ``int shim(ret*, arg0*, ...)``.

    The return value and every argument are passed to the shim by pointer.
    The shim's own ``int`` result is not inspected.
    """

    def __init__(
        self,
        itanium_mangled_name: str,
        shim_writer: object,
        shim_code: str,
        return_type: types.Type,
    ):
        """Create the call convention.

        Parameters
        ----------
        itanium_mangled_name : str
            Mangled name used to derive the shim symbol name.
        shim_writer : object
            Object exposing ``write_to_shim(shim_code, shim_function_name)``.
        shim_code : str
            Source of the shim function.
        return_type : types.Type
            Numba type of the value produced by the call. ``types.void``
            selects the "ignored" return slot described in ``_lower_impl``.
        """
        super().__init__(itanium_mangled_name, shim_writer, shim_code)
        self.return_type = return_type

    def _lower_impl(self, builder, context, sig, args):
        """Emit the by-pointer call to the shim and load its result.

        Parameters
        ----------
        builder
            IR builder used to emit instructions.
        context
            Target context; used for ``get_value_type``.
        sig
            Signature whose ``args`` are the Numba types of ``args``.
        args
            LLVM values of the arguments, in the same order as ``sig.args``.

        Returns
        -------
        The loaded return value, or ``None`` when ``return_type`` is
        ``types.void``.
        """
        returns_void = self.return_type == types.void

        # 1. Prepare return value pointer
        if returns_void:
            # Void return type in C++ is shimmed as an ignored `int&`.
            retval_ty = ir.IntType(32)
            slot_name = "ignored"
        else:
            retval_ty = context.get_value_type(self.return_type)
            slot_name = "retval"
        # Use alloca_once, like the argument slots below, so the slot is
        # allocated in the function's entry block. A plain builder.alloca at
        # the current insertion point would allocate again on every
        # iteration when the call is lowered inside a loop.
        retval_ptr = cgutils.alloca_once(builder, retval_ty, name=slot_name)

        # 2. Prepare arguments
        arg_pointer_types = prepare_ir_types(context, sig.args)

        # All arguments are passed by pointer
        ptrs = [
            cgutils.alloca_once(builder, context.get_value_type(argty))
            for argty in sig.args
        ]
        for ptr, argty, arg in zip(ptrs, sig.args, args):
            # Honor the type's `alignof_` attribute when it defines one.
            builder.store(arg, ptr, align=getattr(argty, "alignof_", None))

        # 3. Declare shim
        # Shim signature: int (retval_type*, arg0_type*, ...)
        fnty = ir.FunctionType(
            ir.IntType(32), [ir.PointerType(retval_ty)] + arg_pointer_types
        )
        fn = cgutils.get_or_insert_function(
            builder.module, fnty, self.shim_function_name
        )

        # 4. Call shim (its int result is intentionally discarded)
        builder.call(fn, (retval_ptr, *ptrs))

        # 5. Return
        if returns_void:
            return None
        return builder.load(
            retval_ptr, align=getattr(self.return_type, "alignof_", None)
        )

numbast/src/numbast/class_template.py

Lines changed: 19 additions & 89 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@
2121
CallableTemplate,
2222
)
2323
from numba.core.datamodel.models import StructModel, OpaqueModel
24-
from numba.cuda import declare_device
2524
from numba.cuda.cudadecl import register_global, register, register_attr
2625
from numba.cuda.cudaimpl import lower
2726
from numba.cuda.core.imputils import numba_typeref_ctor
@@ -44,25 +43,14 @@
4443
)
4544
from numbast.utils import (
4645
deduplicate_overloads,
47-
make_device_caller_with_nargs,
48-
assemble_arglist_string,
49-
assemble_dereferenced_params_string,
46+
make_struct_ctor_shim,
5047
make_struct_regular_method_shim,
5148
)
49+
from numbast.callconv import FunctionCallConv
5250
from numbast.shim_writer import ShimWriterBase
5351

5452
ConcreteTypeCache: dict[str, nbtypes.Type] = {}
5553

56-
ConcreteTypeCache2: dict[str, nbtypes.Type] = {}
57-
58-
struct_ctor_shim_layer_template = """
59-
extern "C" __device__ int
60-
{func_name}(int &ignore, {name} *self {arglist}) {{
61-
new (self) {name}({args});
62-
return 0;
63-
}}
64-
"""
65-
6654

6755
class MetaType(nbtypes.Type):
6856
def __init__(self, template_name):
@@ -111,65 +99,27 @@ def bind_cxx_struct_ctor(
11199
# Lowering
112100
# Note that libclang always considers the return type of a constructor
113101
# is void. So we need to manually specify the return type here.
114-
func_name = deduplicate_overloads(f"{ctor.mangled_name}_nbst")
115-
116-
# FIXME: temporary solution for mismatching function prototype against definition.
117-
# If params are passed by value, at prototype the signature of __nv_bfloat16 is set
118-
# to `b32` type, but to `b64` at definition, causing a linker error. A temporary solution
119-
# is to pass all params by pointer and dereference them in shim. See dereferencing at the
120-
# shim generation below.
121-
ctor_shim_decl = declare_device(
122-
func_name,
123-
nbtypes.int32(
124-
nbtypes.CPointer(s_type_ref.instance_type),
125-
*map(nbtypes.CPointer, param_types),
126-
),
127-
)
128-
129-
ctor_shim_call = make_device_caller_with_nargs(
130-
func_name + "_shim",
131-
1 + len(param_types), # the extra argument for placement new pointer
132-
ctor_shim_decl,
133-
)
102+
mangled_name = deduplicate_overloads(ctor.mangled_name)
103+
shim_func_name = f"{mangled_name}_nbst"
134104

135105
# Dynamically generate the shim layer:
136106
# FIXME: All params are passed by pointers, then dereferenced in shim.
137107
# temporary solution for mismatching function prototype against definition.
138108
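# Known case: when passed by value, __nv_bfloat16 is typed `b32` at the
# prototype but `b64` at the definition, which causes a linker error.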
139-
arglist = assemble_arglist_string(ctor.params)
109+
shim = make_struct_ctor_shim(
110+
shim_name=shim_func_name, struct_name=struct_name, params=ctor.params
111+
)
140112

141-
shim = struct_ctor_shim_layer_template.format(
142-
func_name=func_name,
143-
name=struct_name,
144-
arglist=arglist,
145-
args=assemble_dereferenced_params_string(ctor.params),
113+
ctor_cc = FunctionCallConv(
114+
mangled_name, shim_writer, shim, s_type_ref.instance_type
146115
)
147116

148117
@lower(numba_typeref_ctor, s_type_ref, *param_types)
149118
def ctor_impl(context, builder, sig, args):
150-
s_type = s_type_ref.instance_type
151-
# Delay writing the shim function at lowering time. This avoids writing
152-
# shim functions from the parsed header that's unused in kernels.
153-
shim_writer.write_to_shim(shim, func_name)
154-
155-
selfptr = builder.alloca(context.get_value_type(s_type), name="selfptr")
156-
argptrs = [
157-
builder.alloca(context.get_value_type(arg)) for arg in sig.args[1:]
158-
]
159-
for ptr, ty, arg in zip(argptrs, sig.args[1:], args[1:]):
160-
builder.store(arg, ptr, align=getattr(ty, "alignof_", None))
161-
162-
context.compile_internal(
163-
builder,
164-
ctor_shim_call,
165-
nb_signature(
166-
nbtypes.int32,
167-
nbtypes.CPointer(s_type),
168-
*map(nbtypes.CPointer, param_types),
169-
),
170-
(selfptr, *argptrs),
171-
)
172-
return builder.load(selfptr, align=getattr(s_type, "alignof_", None))
119+
# `numba_typeref_ctor` includes the typeref as the first argument; the
120+
# generated shim expects only the actual constructor params.
121+
ctor_sig = nb_signature(s_type_ref.instance_type, *param_types)
122+
return ctor_cc(builder, context, ctor_sig, args[1:])
173123

174124
return param_types
175125

@@ -236,44 +186,24 @@ def bind_cxx_struct_regular_method(
236186
)
237187

238188
# Lowering
239-
func_name = deduplicate_overloads(f"__{method_decl.mangled_name}_nbst")
240-
241-
c_sig = return_type(
242-
nbtypes.CPointer(s_type), *map(nbtypes.CPointer, param_types)
243-
)
244-
245-
shim_decl = declare_device(func_name, c_sig)
246-
247-
shim_call = make_device_caller_with_nargs(
248-
func_name + "_shim", 1 + len(param_types), shim_decl
249-
)
189+
mangled_name = deduplicate_overloads(f"__{method_decl.mangled_name}")
190+
shim_func_name = f"{mangled_name}_nbst"
250191

251192
shim = make_struct_regular_method_shim(
252-
shim_name=func_name,
193+
shim_name=shim_func_name,
253194
struct_name=struct_decl.name,
254195
method_name=method_decl.name,
255196
return_type=method_decl.return_type.unqualified_non_ref_type_name,
256197
params=method_decl.params,
257198
)
258199

200+
method_cc = FunctionCallConv(mangled_name, shim_writer, shim, return_type)
201+
259202
qualname = f"{s_type}.{method_decl.name}"
260203

261204
@lower(qualname, s_type, *param_types)
262205
def _method_impl(context, builder, sig, args):
263-
shim_writer.write_to_shim(shim, func_name)
264-
265-
argptrs = [
266-
builder.alloca(context.get_value_type(arg)) for arg in sig.args
267-
]
268-
for ptr, ty, arg in zip(argptrs, sig.args, args):
269-
builder.store(arg, ptr, align=getattr(ty, "alignof_", None))
270-
271-
return context.compile_internal(
272-
builder,
273-
shim_call,
274-
c_sig,
275-
argptrs,
276-
)
206+
return method_cc(builder, context, sig, args)
277207

278208
return nb_signature(return_type, *param_types, recvr=s_type)
279209

0 commit comments

Comments
 (0)