Skip to content

Commit 9b35a2e

Browse files
📝 Add docstrings to numbast-pass-by-reference (#282)
Docstrings generation was requested by @isVoid. * #278 (comment) The following files were modified: * `numbast/src/numbast/args.py` * `numbast/src/numbast/callconv.py` * `numbast/src/numbast/class_template.py` * `numbast/src/numbast/function.py` * `numbast/src/numbast/intent.py` * `numbast/src/numbast/static/callconv.py` * `numbast/src/numbast/static/function.py` * `numbast/src/numbast/static/struct.py` * `numbast/src/numbast/static/tests/conftest.py` * `numbast/src/numbast/static/tests/data/src/function_out.cu` * `numbast/src/numbast/static/tests/test_function_static_bindings.py` * `numbast/src/numbast/static/types.py` * `numbast/src/numbast/struct.py` * `numbast/src/numbast/tools/static_binding_generator.py` * `numbast/src/numbast/types.py` * `numbast/tests/test_function.py` <details> <summary>These file types are not supported</summary> * `.pre-commit-config.yaml` * `numbast/src/numbast/static/tests/data/function_out.cuh` * `numbast/tests/data/sample_function_mutative.cuh` * `numbast/tests/data/sample_function_out.cuh` </details> <details> <summary>ℹ️ Note</summary><blockquote> CodeRabbit cannot perform edits on its own pull requests yet. </blockquote></details> Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com> Co-authored-by: Michael Wang <isVoid@users.noreply.github.com>
1 parent 2a90a90 commit 9b35a2e

File tree

17 files changed

+682
-269
lines changed

17 files changed

+682
-269
lines changed

numbast/src/numbast/args.py

Lines changed: 16 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -13,23 +13,22 @@ def prepare_ir_types(
1313
pass_ptr_mask: list[bool] | None = None,
1414
) -> list[ir.Type]:
1515
"""
16-
Prepare IR types for passing arguments via pointers in function calls.
17-
18-
This utility wraps each argument type in a PointerType to enable
19-
the call convention used by FunctionCallConv, where arguments are
20-
passed by reference.
21-
22-
Parameters
23-
----------
24-
context : context object
25-
The compilation context providing the get_value_type method.
26-
argtys : list[ir.Type]
27-
List of LLVM IR types representing function arguments.
28-
29-
Returns
30-
-------
31-
list[ir.Type]
32-
List of pointer types wrapping the value types of each argument.
16+
Prepare LLVM IR types for passing function arguments by reference.
17+
18+
Given a list of argument IR types, return a parallel list of IR types suitable for an ABI that passes arguments by pointer. For each argument, the context's get_value_type() is used to obtain the value type; if the corresponding entry in pass_ptr_mask is True and that value type is already an ir.PointerType, that pointer type is preserved, otherwise the value type is wrapped in an ir.PointerType.
19+
20+
Parameters:
21+
context (CUDATargetContext): Compilation context used to obtain the value type via get_value_type().
22+
argtys (list[ir.Type]): Argument IR types to prepare.
23+
pass_ptr_mask (list[bool] | None): Optional mask the same length as argtys indicating per-argument behavior.
24+
If None, all entries are treated as False. When True for an argument and the value type is an ir.PointerType,
25+
the pointer type is passed through unchanged.
26+
27+
Returns:
28+
list[ir.Type]: Prepared IR types where each entry is either a pointer-to-value or an existing pointer type preserved per pass_ptr_mask.
29+
30+
Raises:
31+
ValueError: If pass_ptr_mask is provided and its length does not match len(argtys).
3332
"""
3433
if pass_ptr_mask is None:
3534
pass_ptr_mask = [False] * len(argtys)

numbast/src/numbast/callconv.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,18 @@ def __init__(
5353
out_return_types: list[types.Type] | None = None,
5454
cxx_return_type: types.Type | None = None,
5555
):
56+
"""
57+
Initialize a FunctionCallConv with shim information and optional ABI/intent hints.
58+
59+
Parameters:
60+
itanium_mangled_name (str): The Itanium-mangled C++ function name used to derive the shim name.
61+
shim_writer (object): Writer used to emit the shim code when required.
62+
shim_code (str): LLVM/IR code template for the shim.
63+
arg_is_ref (list[bool] | None): Per-argument mask indicating whether an argument should be passed as a pointer (True) or by value (False). If None, pointer-passing is determined later or defaults to all False.
64+
intent_plan (IntentPlan | None): Optional plan describing visible parameter indices, which parameters should be passed as pointers, and which parameters are out-returns; when present it drives argument mapping and out-return handling.
65+
out_return_types (list[types.Type] | None): Types of the out-return values in the order declared by the IntentPlan; required when the intent_plan defines out-return indices.
66+
cxx_return_type (types.Type | None): The C++ ABI return type to use for allocating/shimming the return slot; if None, the signature's return type is used.
67+
"""
5668
super().__init__(itanium_mangled_name, shim_writer, shim_code)
5769
self._arg_is_ref = list(arg_is_ref) if arg_is_ref is not None else None
5870
self._intent_plan = intent_plan
@@ -64,6 +76,24 @@ def __init__(
6476
def _lower_impl(self, builder, context, sig, args):
6577
# Numba-visible return type may differ from the underlying C++ return type
6678
# when out_return parameters are enabled (tuple returns, etc.).
79+
"""
80+
Lower the configured call into a shim invocation, preparing return and argument pointers according to arg_is_ref or an IntentPlan and materializing the final Numba-visible return value.
81+
82+
Parameters:
83+
builder: LLVM IR builder used to emit allocations, stores, and calls.
84+
context: Compilation context used to map numba types to LLVM value types and to construct tuple return values.
85+
sig: Numba function signature describing the visible parameter and return types.
86+
args: Sequence of LLVM IR values corresponding to the visible signature parameters.
87+
88+
Returns:
89+
The loaded return value(s) according to the signature and intent plan:
90+
- `None` if the effective C++ return type is void and no out-return values are present.
91+
- A single LLVM value for a single visible return.
92+
- A tuple object constructed via `context.make_tuple` when the visible return type is a tuple; the tuple contains the C++ return (if non-void) followed by any out-return values.
93+
94+
Raises:
95+
ValueError: if the provided IntentPlan does not align with `sig` (mismatched visible_param_indices or pass_ptr_mask), if `out_return_types` are required but missing or length-mismatched, or if a non-tuple visible return is expected but multiple return values are produced.
96+
"""
6797
cxx_return_type = (
6898
self._cxx_return_type
6999
if self._cxx_return_type is not None

numbast/src/numbast/class_template.py

Lines changed: 61 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -85,27 +85,17 @@ def bind_cxx_struct_ctor(
8585
s_type_ref: nbtypes.TypeRef,
8686
shim_writer: ShimWriterBase,
8787
) -> Optional[list]:
88-
"""Create bindings for a C++ struct constructor and return its argument types.
88+
"""
89+
Bind a C++ struct constructor into Numba and return the constructor's argument types.
8990
90-
Parameters
91-
----------
91+
Parameters:
92+
ctor (StructMethod): The C++ constructor declaration to bind.
93+
struct_name (str): The name of the C++ struct being bound.
94+
s_type_ref (numba.types.TypeRef): The Numba TypeRef that represents the struct instantiation.
95+
shim_writer (ShimWriterBase): Writer used to emit the shim function for the constructor.
9296
93-
ctor : StructMethod
94-
Constructor declaration of struct in CXX
95-
struct_name : str
96-
The name of the struct from which this constructor belongs to
97-
s_type : numba.types.Type
98-
The Numba type of the struct
99-
S : object
100-
The Python API of the struct
101-
shim_writer : ShimWriterBase
102-
The shim writer to write the shim layer code.
103-
104-
Returns
105-
-------
106-
list of argument types, optional
107-
If the constructor is a move constructor, return ``None``. Otherwise,
108-
return the list of argument types.
97+
Returns:
98+
list: A list of Numba argument types for the constructor, or `None` if the constructor is a move constructor.
10999
"""
110100

111101
if ctor.is_move_constructor:
@@ -142,6 +132,14 @@ def bind_cxx_struct_ctor(
142132
def ctor_impl(context, builder, sig, args):
143133
# `numba_typeref_ctor` includes the typeref as the first argument; the
144134
# generated shim expects only the actual constructor params.
135+
"""
136+
Lowering implementation for a template-type constructor that delegates to the constructor call-convention with the instance type and actual constructor parameters.
137+
138+
This function builds a constructor signature whose first parameter is the concrete instance type and invokes the prepared FunctionCallConv, passing the original args with the leading typeref argument removed.
139+
140+
Returns:
141+
The value produced by the constructor call-convention (the constructed instance).
142+
"""
145143
ctor_sig = nb_signature(s_type_ref.instance_type, *param_types)
146144
return ctor_cc(builder, context, ctor_sig, args[1:])
147145

@@ -292,6 +290,23 @@ def bind_cxx_struct_regular_method(
292290
*,
293291
arg_intent: dict | None = None,
294292
) -> nb_signature:
293+
"""
294+
Bind a single C++ struct regular method to a Numba-callable signature and register its lowering.
295+
296+
Parameters:
297+
struct_decl (ClassTemplateSpecialization): Parsed C++ class template specialization for the method's declaring type.
298+
method_decl (StructMethod): Parsed method declaration describing name, parameters, and C++ return type.
299+
s_type (nbtypes.Type): The Numba type representing the struct instance (receiver) used in the generated signature.
300+
shim_writer (ShimWriterBase): Writer used to emit the native shim invoked by the lowering.
301+
arg_intent (dict | None, optional): Optional mapping of "TypeName.methodName" -> intent overrides. When provided,
302+
visible parameters, pointer-passing intents, and out-returns are derived from the overrides and may cause the
303+
resulting Numba signature and return type to include out-return values or pointer-wrapped parameters.
304+
305+
Returns:
306+
nb_signature: The Numba signature for the bound method (including the receiver as `recvr`). The signature's return
307+
type reflects any out-return promotion caused by `arg_intent` overrides; otherwise it matches the method's C++
308+
return type.
309+
"""
295310
cxx_return_type = to_numba_type(
296311
method_decl.return_type.unqualified_non_ref_type_name
297312
)
@@ -400,11 +415,16 @@ def bind_cxx_struct_regular_methods(
400415
arg_intent: dict | None = None,
401416
) -> dict[str, ConcreteTemplate]:
402417
"""
418+
Collect concrete typing templates for all regular member functions of a C++ class template specialization.
403419
404-
Return
405-
------
420+
Parameters:
421+
struct_decl (ClassTemplateSpecialization): The parsed C++ class specialization declaration.
422+
s_type (nbtypes.Type): The Numba type representing the instantiated struct.
423+
shim_writer (ShimWriterBase): Writer used to emit shim code for bound methods.
424+
arg_intent (dict | None): Optional mapping of argument-intent overrides (keyed by method or parameter as consumed by the per-method binder); forwarded to individual method bindings.
406425
407-
Mapping from function names to list of signatures.
426+
Returns:
427+
dict[str, ConcreteTemplate]: Mapping from method name to a ConcreteTemplate whose cases are the collected Numba signatures for that method's overloads.
408428
"""
409429

410430
method_overloads: dict[str, list[nb_signature]] = defaultdict(list)
@@ -804,29 +824,25 @@ def bind_cxx_class_template_specialization(
804824
arg_intent: dict | None = None,
805825
) -> object:
806826
"""
807-
Create bindings for a C++ struct.
808-
809-
Parameters
810-
----------
811-
shim_writer : ShimWriterBase
812-
The shim writer to write the shim layer code.
813-
struct_decl : Struct
814-
Declaration of the struct type in CXX
815-
parent_type : nbtypes.Type, optional
816-
Parent type of the Python API, by default nbtypes.Type
817-
data_model : type, optional
818-
Data model for the struct, by default StructModel
819-
aliases : dict[str, list[str]], optional
820-
Mappings from the name of the struct to a list of aliases.
821-
For example in C++: typedef A B; typedef A C; then
822-
aliases = {"A": ["B", "C"]}
823-
824-
Returns
825-
-------
826-
S : object
827-
The Python API of the struct.
828-
shim: str
829-
The generated shim layer code for struct methods.
827+
Bind a C++ class template specialization into Numba and return the corresponding Numba instance type.
828+
829+
This registers the C++ name-to-Numba-type mapping, installs a StructModel for the instantiated type, registers attribute and method typing templates, and binds constructors so the specialization can be used from Numba.
830+
831+
Parameters:
832+
shim_writer: ShimWriterBase
833+
Utility used to emit shim-layer code for bound methods.
834+
struct_decl: ClassTemplateSpecialization
835+
Parsed C++ class template specialization declaration to bind.
836+
instance_type_ref: nbtypes.Type
837+
The Numba TypeRef representing the instantiated template; its .instance_type is the returned type.
838+
aliases: dict[str, list[str]], optional
839+
Optional name aliases for the C++ type (e.g., typedefs) to register to the same Numba type.
840+
arg_intent: dict | None, optional
841+
Optional per-argument intent overrides to influence method parameter/return typing.
842+
843+
Returns:
844+
nbtypes.Type
845+
The Numba instance type for the bound class template specialization (instance_type_ref.instance_type).
830846
"""
831847

832848
s_type = instance_type_ref.instance_type

numbast/src/numbast/function.py

Lines changed: 47 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -48,20 +48,15 @@ def bind_cxx_operator_overload_function(
4848
*,
4949
arg_intent: dict | None = None,
5050
) -> object:
51-
"""Create bindings for a C++ operator-overload function.
52-
53-
Parameters
54-
----------
55-
shim_writer : ShimWriter
56-
The shim writer to write the generated shim layer code.
51+
"""
52+
Create a Numba-callable binding for a C++ operator-overload function.
5753
58-
func_decl : Function
59-
The declaration of the function in C++.
54+
Parameters:
55+
func_decl (Function): C++ function declaration to bind.
56+
arg_intent (dict | None): Optional mapping that customizes argument intent (e.g., which reference parameters are treated as input, output, or inout). When provided, intent controls visible parameter/pointer treatment and out-return composition.
6057
61-
Returns
62-
-------
63-
shim_call : object
64-
The Numba-CUDA-callable Python API for the function.
58+
Returns:
59+
FunctionCallConv | None: A callable wrapper used during lowering that performs the bound call, or `None` when the operator is unsupported (e.g., copy assignment operators).
6560
"""
6661
if func_decl.is_copy_assignment_operator():
6762
# copy assignment operator, do not support in Numba / Python, skip
@@ -110,6 +105,18 @@ class op_decl(ConcreteTemplate):
110105

111106
@lower(py_op, *param_types)
112107
def impl(context, builder, sig, args):
108+
"""
109+
Delegate lowering to the captured FunctionCallConv instance `func_cc`.
110+
111+
Parameters:
112+
context: Numba lowering context used during compilation.
113+
builder: LLVM IR builder used to emit instructions.
114+
sig: The function signature being lowered.
115+
args: Sequence of lowered argument values passed to the call.
116+
117+
Returns:
118+
The lowered native value(s) produced by `func_cc`.
119+
"""
113120
return func_cc(builder, context, sig, args)
114121

115122
return func_cc
@@ -123,27 +130,28 @@ def bind_cxx_non_operator_function(
123130
*,
124131
arg_intent: dict | None = None,
125132
) -> object:
126-
"""Create bindings for a C++ non-operator function.
133+
"""
134+
Create a Python-callable binding for a C++ non-operator function.
135+
136+
Optionally uses an arg_intent override to control which C++ reference parameters are exposed as pointer parameters or returned as out-returns; when no overrides are provided, reference parameters are treated as input-only values.
127137
128138
Parameters
129139
----------
130140
shim_writer : ShimWriter
131-
The shim writer to write the generated shim layer code.
132-
141+
Writer used to emit the generated shim layer code.
133142
func_decl : Function
134-
The declaration of the function in C++.
135-
143+
C++ function declaration to bind.
136144
skip_prefix : str | None
137-
Skip functions with this prefix. Has no effect if None or empty.
138-
145+
If provided, skip functions whose names start with this prefix.
139146
exclude : set[str]
140-
A set of function names to exclude.
141-
147+
Set of function names to exclude from binding.
148+
arg_intent : dict | None, optional
149+
Optional per-function intent overrides that specify visibility and in/out semantics for reference parameters.
142150
143151
Returns
144152
-------
145-
func : object
146-
The Python-callable API for the function.
153+
object
154+
The Python-callable function object registered for the bound C++ function, or `None` if the function is skipped.
147155
"""
148156
global overload_registry
149157

@@ -262,29 +270,22 @@ def bind_cxx_function(
262270
*,
263271
arg_intent: dict | None = None,
264272
) -> object:
265-
"""Create bindings for a C++ function.
266-
267-
Parameters
268-
----------
269-
shim_writer : ShimWriter
270-
The shim writer to write the generated shim layer code.
271-
272-
func_decl : Function
273-
Declaration of the function in CXX
274-
275-
skip_prefix : str | None
276-
Skip functions with this prefix. Has no effect if None or empty.
277-
278-
skip_non_device : bool
279-
Skip non device functions. Default to True.
280-
281-
exclude : set[str]
282-
A set of function names to exclude. Default to empty set.
283-
284-
Returns
285-
-------
286-
func : object
287-
The Numba-CUDA-callable Python API for the function.
273+
"""
274+
Create Python bindings for a C++ function.
275+
276+
Parameters:
277+
shim_writer (ShimWriter): Writer that emits the generated C/C++ shim code.
278+
func_decl (Function): C++ function declaration to bind.
279+
skip_prefix (str | None): If provided, skip functions whose names start with this prefix.
280+
skip_non_device (bool): If True, skip functions not marked for device or host_device execution.
281+
exclude (set[str]): Names of functions to exclude from binding.
282+
arg_intent (dict | None): Optional explicit intent overrides that control which C++ reference
283+
parameters are exposed as inputs, outputs, or inout pointers and which parameters are
284+
promoted to out-returns.
285+
286+
Returns:
287+
object or None: The Numba-CUDA-callable Python binding object for the function, or `None`
288+
if the function is skipped or not exposed.
288289
"""
289290

290291
if skip_non_device and func_decl.exec_space not in {

0 commit comments

Comments
 (0)