Skip to content

Commit 2a90a90

Browse files
authored
Add FunctionTemplate Generation (#290)
This PR generates Numba bindings for C++ function templates. <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit * **New Features** * Added C++ templated function binding support for Numba CUDA, enabling users to bind generic C++ functions with type deduction and overload selection. * **Refactor** * Extracted overload selection logic into a dedicated internal module for improved code organization. * **Tests** * Added comprehensive test module and sample CUDA header with templated functions covering specializations, default arguments, and out-parameters. <!-- end of auto-generated comment: release notes by coderabbit.ai --> --------- Co-authored-by: Michael Wang <isVoid@users.noreply.github.com>
1 parent 2fddfa3 commit 2a90a90

File tree

7 files changed

+735
-62
lines changed

7 files changed

+735
-62
lines changed

numbast/src/numbast/__init__.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,10 @@
1010
bind_cxx_class_templates,
1111
)
1212
from numbast.function import bind_cxx_function, bind_cxx_functions
13+
from numbast.function_template import (
14+
bind_cxx_function_template,
15+
bind_cxx_function_templates,
16+
)
1317
from numbast.enum import bind_cxx_enum, bind_cxx_enums
1418
from numbast.shim_writer import MemoryShimWriter, FileShimWriter
1519

@@ -27,6 +31,8 @@
2731
"bind_cxx_enums",
2832
"bind_cxx_function",
2933
"bind_cxx_functions",
34+
"bind_cxx_function_template",
35+
"bind_cxx_function_templates",
3036
"bind_cxx_struct",
3137
"bind_cxx_structs",
3238
"bind_cxx_class_template_specialization",

numbast/src/numbast/class_template.py

Lines changed: 1 addition & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@
5656
from numbast.callconv import FunctionCallConv
5757
from numbast.shim_writer import ShimWriterBase
5858
from numbast.deduction import deduce_templated_overloads
59+
from numbast.overload_selection import _select_templated_overload
5960

6061

6162
logger = logging.getLogger(__name__)
@@ -649,67 +650,6 @@ def _impl(
649650
return MergedTemplatedMethodDecl
650651

651652

652-
def _select_templated_overload(
653-
*,
654-
qualname: str,
655-
overloads: list[FunctionTemplate],
656-
param_types: tuple[nbtypes.Type, ...],
657-
kwds: dict[str, Any] | None = None,
658-
overrides: dict | None = None,
659-
) -> FunctionTemplate:
660-
"""
661-
Select a FunctionTemplate overload for a templated method.
662-
663-
Today we only select by explicit argument count (visible arity). Keep this
664-
logic centralized so we can expand it with C++-style overload resolution:
665-
- filter viable candidates (arity/defaults/variadics, arg_intent visibility),
666-
- rank implicit conversions (exact > promotion > standard > user-defined),
667-
- prefer better ref/cv binding and non-variadic over variadic,
668-
- prefer more specialized templates / stronger constraints,
669-
- treat remaining ties as ambiguous.
670-
"""
671-
arity = len(param_types)
672-
candidates: list[FunctionTemplate] = []
673-
intent_errors: list[Exception] = []
674-
675-
for m in overloads:
676-
if overrides is None:
677-
visible_arity = len(m.function.params)
678-
else:
679-
try:
680-
plan = compute_intent_plan(
681-
params=m.function.params,
682-
param_types=m.function.param_types,
683-
overrides=overrides,
684-
allow_out_return=True,
685-
)
686-
except Exception as exc:
687-
intent_errors.append(exc)
688-
continue
689-
visible_arity = len(plan.visible_param_indices)
690-
691-
if visible_arity == arity:
692-
candidates.append(m)
693-
694-
if overrides is not None and not candidates and intent_errors:
695-
raise TypeError(
696-
f"Failed to apply arg_intent overrides for {qualname}: "
697-
f"{intent_errors[0]}"
698-
)
699-
if not candidates:
700-
raise TypeError(
701-
f"No matching overload found for {qualname} with {arity} args. "
702-
f"Overload arities: {[len(m.function.params) for m in overloads]}"
703-
)
704-
if len(candidates) > 1:
705-
raise TypeError(
706-
f"Ambiguous overload for {qualname} with {arity} args. "
707-
f"Matching overload arities: {[len(m.function.params) for m in candidates]}"
708-
)
709-
710-
return candidates[0]
711-
712-
713653
_CXX_ARRAY_TYPE_RE = re.compile(r"^(?P<base>.*?)(?P<sizes>(\[[^\]]+\])+)\s*$")
714654

715655

0 commit comments

Comments
 (0)