Skip to content

Commit 2511738

Browse files
authored
Class Template - Method Template (Parent-Child pattern) Binding Generation (#286)
This is a demo PR that shows a "Parent-Child" pattern that shows an end-to-end use case of class template instantiation, template argument deduction and overload selection working hand in hand. - Add templated member-function bindings for class template specializations with overload selection and `arg_intent` support. - Generate templated-method shims with explicit argument formatting (array-ref/pointer handling) and intent-aware returns. - Improve template specialization metadata and typing (qualified/specialized names, enum NTTP names, CCCL `NullType`/opaque fallback, quieter parse defaults). - Add class-template unit tests and CCCL integration tests; run CCCL tests in conda and wheels CI. <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit * **New Features** * Templated-method support for class templates with per-method argument-intent overrides * **Improvements** * Enum-aware rendering of template arguments * Reduced default parsing verbosity * Looser type deduction for opaque mappings * More consistent qualified template-name handling * **Tests** * New CCCL (CUB) test suite and CI step to run it * Expanded templated-class method tests * **Chores** * Updated CI/test scripts and pre-commit codespell ignores * **Documentation** * Added SPDX headers to third-party test modules <!-- end of auto-generated comment: release notes by coderabbit.ai --> --------- Co-authored-by: Michael Wang <isVoid@users.noreply.github.com>
1 parent aea243e commit 2511738

21 files changed

+1263
-50
lines changed

.github/workflows/wheels-test.yaml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -135,3 +135,6 @@ jobs:
135135
- name: Test numbast wheels
136136
run: |
137137
python -m pytest numbast/
138+
- name: Run numbast_extensions third-party CCCL tests
139+
run: |
140+
python -m pytest numbast_extensions/tests/thirdparty/CCCL/

.pre-commit-config.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ repos:
3737
- id: codespell
3838
additional_dependencies:
3939
- tomli
40-
args: ["--toml", "pyproject.toml", "--ignore-words-list", "inout"]
40+
args: ["--toml", "pyproject.toml", "--ignore-words-list", "inout,thirdparty"]
4141
- repo: https://github.com/google/yamlfmt
4242
rev: v0.16.0
4343
hooks:

ast_canopy/ast_canopy/api.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -341,7 +341,7 @@ def parse_declarations_from_source(
341341
cxx_standard: str = "gnu++17",
342342
additional_includes: list[str] = [],
343343
defines: list[str] = [],
344-
verbose: bool = True,
344+
verbose: bool = False,
345345
bypass_parse_error: bool = False,
346346
) -> Declarations:
347347
"""Given a source file, parse all top-level declarations from it and return

ast_canopy/ast_canopy/decl.py

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -391,6 +391,10 @@ def regular_member_functions(self):
391391
):
392392
yield m
393393

394+
def templated_member_functions(self):
395+
"""Generator for templated methods."""
396+
yield from self.templated_methods
397+
394398
@classmethod
395399
def from_c_obj(cls, c_obj: bindings.Record, parse_entry_point: str):
396400
return cls(
@@ -509,7 +513,7 @@ def value(self):
509513
return CXX_TYPE_TO_PYTHON_TYPE[cxx_type_name](self.value_serialized)
510514

511515

512-
class ClassTemplateSpecialization(Struct, ClassInstantiation):
516+
class ClassTemplateSpecialization(Struct):
513517
"""Represents a C++ class template specialization declaration.
514518
515519
Holds the underlying ``TemplatedStruct`` and provides ``instantiate`` for
@@ -535,7 +539,8 @@ def __init__(
535539
record.alignof_,
536540
record.parse_entry_point,
537541
)
538-
ClassInstantiation.__init__(self, class_template)
542+
543+
self._instantiation = ClassInstantiation(class_template)
539544

540545
targ_names: list[str] = [
541546
tp.name for tp in class_template.template_parameters
@@ -544,7 +549,7 @@ def __init__(
544549

545550
kwargs = dict(zip(targ_names, targ_values))
546551

547-
self.instantiate(**kwargs)
552+
self._instantiation.instantiate(**kwargs)
548553

549554
self.actual_template_arguments = actual_template_arguments
550555

@@ -560,11 +565,19 @@ def from_c_obj(
560565

561566
@property
562567
def name(self):
563-
return self.get_instantiated_c_stmt()
568+
return self._name
569+
570+
@property
571+
def specialized_name(self):
572+
return self._instantiation.get_instantiated_c_stmt(use_qual_name=False)
573+
574+
@property
575+
def qual_name(self):
576+
return self._instantiation.get_instantiated_c_stmt(use_qual_name=True)
564577

565578
@property
566579
def base_name(self):
567-
return self.record.name
580+
return self._name
568581

569582
def constructors(self):
570583
for m in self.methods:

ast_canopy/ast_canopy/instantiations.py

Lines changed: 22 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -37,14 +37,20 @@ def param_list(self):
3737
for tparam in self.template_parameters
3838
]
3939

40-
def get_instantiated_c_stmt(self) -> str:
41-
name = self.base_name
40+
def get_instantiated_c_stmt(self, use_qual_name: bool = False) -> str:
41+
if use_qual_name:
42+
name = self.qual_name
43+
else:
44+
name = self.base_name
45+
4246
param_list = self.param_list
4347

4448
flatten = []
4549
for param in param_list:
4650
if isinstance(param, BaseInstantiation):
47-
flatten.append(param.get_instantiated_c_stmt())
51+
flatten.append(
52+
param.get_instantiated_c_stmt(use_qual_name=use_qual_name)
53+
)
4854
else:
4955
flatten.append(str(param))
5056

@@ -56,6 +62,11 @@ def base_name(self):
5662
"BaseInstantiation.base_name is not implemented"
5763
)
5864

65+
def qual_name(self):
66+
raise NotImplementedError(
67+
"BaseInstantiation.qual_name is not implemented"
68+
)
69+
5970

6071
class FunctionInstantiation(BaseInstantiation):
6172
"""Represent an instantiation of a function template."""
@@ -68,6 +79,10 @@ def __init__(self, function_template: ast_canopy.decl.FunctionTemplate):
6879
def base_name(self):
6980
return self.function.name
7081

82+
@property
83+
def qual_name(self):
84+
return self.function.qual_name
85+
7186
def evaluate_constexpr_value(self, *args, header=None):
7287
if not self.function.is_constexpr:
7388
raise ValueError("Function is not constexpr")
@@ -124,3 +139,7 @@ def __init__(self, class_template: ast_canopy.decl.ClassTemplate):
124139
@property
125140
def base_name(self):
126141
return self.record.name
142+
143+
@property
144+
def qual_name(self):
145+
return self.record.qual_name

ast_canopy/cpp/src/class_template_specialization.cpp

Lines changed: 38 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -23,17 +23,50 @@ ClassTemplateSpecialization::ClassTemplateSpecialization(
2323

2424
for (auto i = 0; i < tparam_list.size(); i++) {
2525
const auto &targ = tparam_list[i];
26+
2627
clang::TemplateArgument::ArgKind kind = targ.getKind();
2728
switch (kind) {
2829
case clang::TemplateArgument::ArgKind::Type:
2930
actual_template_arguments.push_back(targ.getAsType().getAsString());
3031
break;
3132
case clang::TemplateArgument::ArgKind::Integral: {
32-
llvm::APSInt integer = targ.getAsIntegral();
33-
llvm::SmallString<32> str; // -uint64_t::max() is 21 digits (with sign).
34-
// This should be enough.
35-
integer.toString(str);
36-
actual_template_arguments.push_back(str.c_str());
33+
clang::QualType T = targ.getIntegralType();
34+
35+
// If this integral NTTP is actually an enum, try to recover the
36+
// enumerator name from the value.
37+
if (T->isEnumeralType()) {
38+
const auto *ET = T->getAs<clang::EnumType>();
39+
const clang::EnumDecl *ED = ET->getDecl();
40+
const llvm::APSInt &Val = targ.getAsIntegral();
41+
42+
const clang::EnumConstantDecl *Matched = nullptr;
43+
for (const auto *ECD : ED->enumerators()) {
44+
if (ECD->getInitVal() == Val) {
45+
Matched = ECD;
46+
break;
47+
}
48+
}
49+
50+
if (Matched) {
51+
// Use the fully-qualified enumerator name if you like,
52+
// or just ECD->getNameAsString() for the bare name.
53+
actual_template_arguments.push_back(
54+
Matched->getQualifiedNameAsString());
55+
} else {
56+
// No enumerator found with that value; fall back to integer.
57+
llvm::APSInt integer = Val;
58+
llvm::SmallString<32> str;
59+
integer.toString(str);
60+
actual_template_arguments.push_back(str.c_str());
61+
}
62+
} else {
63+
// Plain integral (not an enum type) — keep your existing behavior.
64+
llvm::APSInt integer = targ.getAsIntegral();
65+
llvm::SmallString<32> str; // -uint64_t::max() is 21 digits (with sign).
66+
integer.toString(str);
67+
actual_template_arguments.push_back(str.c_str());
68+
}
69+
3770
break;
3871
}
3972
default:

ast_canopy/tests/test_class_template_specialization.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ def test_ctpsd_unparsed_in_structs(sample_ctsd):
1919

2020
# At this stage, assert that the ctpsd is not parsed.
2121
assert len(ctsd) == 1
22-
assert ctsd[0].name == "BlockScan<int, 128>"
22+
assert ctsd[0].specialized_name == "BlockScan<int, 128>"
2323

2424
assert len(ctsd[0].fields) == 0
2525
assert len(ctsd[0].methods) == 2

ci/run_tests.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,13 +38,15 @@ def run_pytest(lib, test_dir):
3838
@click.option(
3939
"--curand_device", is_flag=True, help="Run curand device binding pytests."
4040
)
41+
@click.option("--cccl", is_flag=True, help="Run CCCL (CUB) binding pytests.")
4142
@click.option("--all-tests", is_flag=True, help="Run all pytests.")
4243
def run(
4344
ast_canopy: bool,
4445
numbast: bool,
4546
bf16: bool,
4647
fp16: bool,
4748
curand_device: bool,
49+
cccl: bool,
4850
all_tests: bool,
4951
):
5052
"""Selectively run pytests in Numbast repo based on options provided.
@@ -53,7 +55,7 @@ def run(
5355
package. `--all-tests` option is mutually exclusive to all other options.
5456
"""
5557
if all_tests:
56-
if any([ast_canopy, numbast, bf16, fp16, curand_device]):
58+
if any([ast_canopy, numbast, bf16, fp16, curand_device, cccl]):
5759
raise ValueError(
5860
"`all_tests` and any subpackage specs are mutual exclusive."
5961
)
@@ -82,6 +84,8 @@ def run(
8284
)
8385
if all_tests or curand_device:
8486
run_pytest("curand_device", ["numbast_extensions/tests/test_curand.py"])
87+
if all_tests or cccl:
88+
run_pytest("cccl", ["numbast_extensions/tests/thirdparty/CCCL/"])
8589

8690

8791
if __name__ == "__main__":

ci/test_conda_python.sh

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -49,8 +49,7 @@ trap "EXITCODE=1" ERR
4949
set +e
5050

5151
rapids-logger "Run Tests"
52-
# Debug print
53-
python ci/run_tests.py --ast-canopy --numbast
52+
python ci/run_tests.py --ast-canopy --numbast --cccl
5453

5554
rapids-logger "Test script exiting with value: $EXITCODE"
5655
exit ${EXITCODE}

0 commit comments

Comments
 (0)