Skip to content

Commit e0b331b

Browse files
authored
Capture Enum Underlying Type in AST_CANOPY, documenting the C++/Python Enum lowering (#267)
This PR adds feature to capture the c++ enum underlying integral type in ast_canopy. As well as capturing the handling of binding to function with enum argument in docs. This PR also tests whether `cuda.bindings.runtime.cudaRoundMode` works well with existing bindings. <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit * **New Features** * Enum underlying integer types are exposed to Python bindings (includes cudaRoundMode); enum handling and registration improved. * **Refactor** * Call/return handling internals simplified (no user-visible API change); several type-resolution flows consolidated. * Public typing and metadata structures extended to include enum underlying-type information. * **Tests** * Added and updated tests exercising enum underlying types, device-side enum helpers, and streamlined binding-generation fixtures. * **Documentation** * FAQ expanded with notes on enum binding and lowering behavior. <sub>✏️ Tip: You can customize this high-level summary in your review settings.</sub> <!-- end of auto-generated comment: release notes by coderabbit.ai --> --------- Co-authored-by: Michael Wang <isVoid@users.noreply.github.com>
1 parent f91d941 commit e0b331b

26 files changed

+401
-216
lines changed

ast_canopy/ast_canopy/pylibastcanopy.cpp

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -49,16 +49,24 @@ PYBIND11_MODULE(pylibastcanopy, m) {
4949
.def_readwrite("name", &Enum::name)
5050
.def_readwrite("enumerators", &Enum::enumerators)
5151
.def_readwrite("enumerator_values", &Enum::enumerator_values)
52+
.def_readwrite("underlying_type", &Enum::underlying_type)
5253
.def(py::pickle(
5354
[](const Enum &e) {
54-
return py::make_tuple(e.name, e.enumerators, e.enumerator_values);
55+
return py::make_tuple(e.name, e.enumerators, e.enumerator_values,
56+
e.underlying_type);
5557
},
5658
[](py::tuple t) {
57-
if (t.size() != 3)
59+
// Backward compat: older pickles only stored (name, enumerators,
60+
// enumerator_values).
61+
if (t.size() != 3 && t.size() != 4)
5862
throw std::runtime_error("Invalid enum state during unpickle!");
59-
return Enum{t[0].cast<std::string>(),
60-
t[1].cast<std::vector<std::string>>(),
61-
t[2].cast<std::vector<std::string>>()};
63+
Enum e{t[0].cast<std::string>(),
64+
t[1].cast<std::vector<std::string>>(),
65+
t[2].cast<std::vector<std::string>>()};
66+
if (t.size() == 4) {
67+
e.underlying_type = t[3].cast<Type>();
68+
}
69+
return e;
6270
}));
6371

6472
py::class_<Type>(m, "Type")

ast_canopy/ast_canopy/pylibastcanopy.pyi

Lines changed: 23 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import collections.abc
2+
import typing
13
from _typeshed import Incomplete
24
from typing import ClassVar, overload
35

@@ -6,13 +8,19 @@ class ClassTemplate(Template):
68
record: Incomplete
79
def __init__(self, *args, **kwargs) -> None: ...
810

11+
class ClassTemplateSpecialization(Record):
12+
actual_template_arguments: list[str]
13+
class_template: ClassTemplate
14+
def __init__(self, *args, **kwargs) -> None: ...
15+
916
class ConstExprVar:
1017
name: str
1118
type_: Type
1219
value: str
1320
def __init__(self) -> None: ...
1421

1522
class Declarations:
23+
class_template_specializations: list[ClassTemplateSpecialization]
1624
class_templates: list[ClassTemplate]
1725
enums: list[Enum]
1826
function_templates: list[FunctionTemplate]
@@ -25,6 +33,7 @@ class Enum:
2533
enumerator_values: list[str]
2634
enumerators: list[str]
2735
name: str
36+
underlying_type: Incomplete
2837
def __init__(self, arg0) -> None: ...
2938

3039
class Field:
@@ -34,6 +43,7 @@ class Field:
3443
def __init__(self, *args, **kwargs) -> None: ...
3544

3645
class Function:
46+
attributes: set[str]
3747
exec_space: execution_space
3848
is_constexpr: bool
3949
mangled_name: str
@@ -73,7 +83,11 @@ class Record:
7383
class Template:
7484
num_min_required_args: int
7585
template_parameters: list[TemplateParam]
76-
def __init__(self, arg0: list[TemplateParam], arg1: int) -> None: ...
86+
def __init__(
87+
self,
88+
arg0: collections.abc.Sequence[TemplateParam],
89+
arg1: typing.SupportsInt,
90+
) -> None: ...
7791

7892
class TemplateParam:
7993
kind: template_param_kind
@@ -104,7 +118,7 @@ class access_kind:
104118
private_: ClassVar[access_kind] = ...
105119
protected_: ClassVar[access_kind] = ...
106120
public_: ClassVar[access_kind] = ...
107-
def __init__(self, value: int) -> None: ...
121+
def __init__(self, value: typing.SupportsInt) -> None: ...
108122
def __eq__(self, other: object) -> bool: ...
109123
def __hash__(self) -> int: ...
110124
def __index__(self) -> int: ...
@@ -123,7 +137,7 @@ class execution_space:
123137
host: ClassVar[execution_space] = ...
124138
host_device: ClassVar[execution_space] = ...
125139
undefined: ClassVar[execution_space] = ...
126-
def __init__(self, value: int) -> None: ...
140+
def __init__(self, value: typing.SupportsInt) -> None: ...
127141
def __eq__(self, other: object) -> bool: ...
128142
def __hash__(self) -> int: ...
129143
def __index__(self) -> int: ...
@@ -145,7 +159,7 @@ class method_kind:
145159
move_constructor: ClassVar[method_kind] = ...
146160
other: ClassVar[method_kind] = ...
147161
other_constructor: ClassVar[method_kind] = ...
148-
def __init__(self, value: int) -> None: ...
162+
def __init__(self, value: typing.SupportsInt) -> None: ...
149163
def __eq__(self, other: object) -> bool: ...
150164
def __hash__(self) -> int: ...
151165
def __index__(self) -> int: ...
@@ -162,7 +176,7 @@ class template_param_kind:
162176
non_type: ClassVar[template_param_kind] = ...
163177
template_: ClassVar[template_param_kind] = ...
164178
type_: ClassVar[template_param_kind] = ...
165-
def __init__(self, value: int) -> None: ...
179+
def __init__(self, value: typing.SupportsInt) -> None: ...
166180
def __eq__(self, other: object) -> bool: ...
167181
def __hash__(self) -> int: ...
168182
def __index__(self) -> int: ...
@@ -174,8 +188,10 @@ class template_param_kind:
174188
def value(self) -> int: ...
175189

176190
def parse_declarations_from_command_line(
177-
arg0: list[str], arg1: list[str]
191+
arg0: collections.abc.Sequence[str],
192+
arg1: collections.abc.Sequence[str],
193+
arg2: bool,
178194
) -> Declarations: ...
179195
def value_from_constexpr_vardecl(
180-
arg0: list[str], arg1: str
196+
arg0: collections.abc.Sequence[str], arg1: str
181197
) -> ConstExprVar | None: ...

ast_canopy/cpp/include/ast_canopy/ast_canopy.hpp

Lines changed: 13 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -40,18 +40,6 @@ enum class template_param_kind { type, non_type, template_ };
4040

4141
enum class access_kind { public_, protected_, private_ };
4242

43-
struct Enum {
44-
Enum(const std::string &name, const std::vector<std::string> &enumerators,
45-
const std::vector<std::string> &enumerator_values)
46-
: name(name), enumerators(enumerators),
47-
enumerator_values(enumerator_values) {}
48-
Enum(const clang::EnumDecl *);
49-
50-
std::string name;
51-
std::vector<std::string> enumerators;
52-
std::vector<std::string> enumerator_values;
53-
};
54-
5543
struct Type {
5644
Type() = default;
5745
Type(std::string name, std::string unqualified_non_ref_type_name,
@@ -69,6 +57,19 @@ struct Type {
6957
bool _is_left_reference;
7058
};
7159

60+
struct Enum {
61+
Enum(const std::string &name, const std::vector<std::string> &enumerators,
62+
const std::vector<std::string> &enumerator_values)
63+
: name(name), enumerators(enumerators),
64+
enumerator_values(enumerator_values) {}
65+
Enum(const clang::EnumDecl *);
66+
67+
std::string name;
68+
std::vector<std::string> enumerators;
69+
std::vector<std::string> enumerator_values;
70+
Type underlying_type;
71+
};
72+
7273
struct ConstExprVar {
7374
ConstExprVar() = default;
7475
ConstExprVar(const clang::VarDecl *VD);

ast_canopy/cpp/src/enum.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,9 @@
77

88
namespace ast_canopy {
99

10-
Enum::Enum(const clang::EnumDecl *ED) : name(ED->getNameAsString()) {
10+
Enum::Enum(const clang::EnumDecl *ED)
11+
: name(ED->getNameAsString()),
12+
underlying_type(ED->getIntegerType(), ED->getASTContext()) {
1113
for (const auto *enumerator : ED->enumerators()) {
1214
enumerators.push_back(enumerator->getNameAsString());
1315

ast_canopy/tests/data/sample_enum.cu

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,3 +7,7 @@ enum Foo { A = 1, B = 2, C = 3 };
77
enum NoDefaultFoo { D, E, F };
88

99
enum class Bar { A = 1 };
10+
11+
enum Fruit : uint64_t { Apple, Banana, Orange };
12+
enum Car : int16_t { Sedan, SUV, Pickup, Hatchback };
13+
enum Plane : char { Boeing, Airbus, Embraer };

ast_canopy/tests/test_parse_from_source.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -467,7 +467,7 @@ def test_load_enum(sample_enum_source, test_pickle):
467467
pickled = [pickle.dumps(e) for e in enums]
468468
enums = [pickle.loads(p) for p in pickled]
469469

470-
assert len(enums) == 3
470+
assert len(enums) == 6
471471
assert enums[0].name == "Foo"
472472
assert len(enums[0].enumerators) == 3
473473
assert enums[0].enumerators[0] == "A"
@@ -495,6 +495,24 @@ def test_load_enum(sample_enum_source, test_pickle):
495495
assert enums[2].enumerator_values[0] == "1"
496496

497497

498+
def test_load_enum_underlying_type(sample_enum_source, test_pickle):
499+
decls = parse_declarations_from_source(
500+
sample_enum_source, [sample_enum_source], "sm_80"
501+
)
502+
503+
enums = decls.enums
504+
if test_pickle:
505+
pickled = [pickle.dumps(e) for e in enums]
506+
enums = [pickle.loads(p) for p in pickled]
507+
508+
assert enums[3].name == "Fruit"
509+
assert enums[3].underlying_type.name == "uint64_t"
510+
assert enums[4].name == "Car"
511+
assert enums[4].underlying_type.name == "int16_t"
512+
assert enums[5].name == "Plane"
513+
assert enums[5].underlying_type.name == "char"
514+
515+
498516
def test_load_struct_function_execution_space(
499517
sample_execution_space_source, test_pickle
500518
):

docs/source/faq.rst

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -167,3 +167,26 @@ bindings generated with a specific version of Numbast are tested against a speci
167167
.. note::
168168
These version restrictions may be relaxed or removed once ``numba-cuda`` releases a stable 1.0 version with
169169
stabilized public APIs. Until then, bindings are tested against specific version ranges to ensure compatibility.
170+
171+
172+
C++ Enum Binding Generation Notes
173+
---------------------------------
174+
175+
**Why do Numbast bindings treat C++ enums as ``int64`` in Numba?**
176+
177+
Numba represents Python ``IntEnum`` values using ``IntEnumMember(..., int64)`` (i.e., enum values are lowered as
178+
64-bit integers). Numbast follows this convention in both dynamic and static binding generation so that Python-side
179+
typing and lowering are consistent.
180+
181+
**But C++ enums can have different underlying integer types. Why don't we track and truncate to that type in lowering?**
182+
183+
Numbast does not keep a per-enum “underlying integer type” registry and does not perform explicit truncation during
184+
lowering because the device-side shim is compiled by NVRTC, and the shim call site is where C++ type checking happens.
185+
Even though the Python/Numba side lowers enum values as 64-bit integers, NVRTC can resolve the target enum type and emit
186+
the appropriate conversion when the shim calls the original function that takes the C++ enum parameter. This means we
187+
don't need to track per-enum underlying integer types in Python or add special-case truncation/casting logic in Numbast
188+
lowering.
189+
190+
If you are binding code that depends on unusual enum representations or non-standard ABIs, you may need a custom
191+
adapter. For typical CUDA device code, this approach keeps the implementation simpler and avoids maintaining extra
192+
metadata for every enum type.

numbast/src/numbast/args.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
22
# SPDX-License-Identifier: Apache-2.0
33

4-
54
from numba.cuda.target import CUDATargetContext
5+
66
from llvmlite import ir
77

88

numbast/src/numbast/callconv.py

Lines changed: 5 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -36,24 +36,15 @@ def __call__(self, builder, context, sig, args):
3636

3737

3838
class FunctionCallConv(BaseCallConv):
39-
def __init__(
40-
self,
41-
itanium_mangled_name: str,
42-
shim_writer: object,
43-
shim_code: str,
44-
return_type: types.Type,
45-
):
46-
super().__init__(itanium_mangled_name, shim_writer, shim_code)
47-
self.return_type = return_type
48-
4939
def _lower_impl(self, builder, context, sig, args):
40+
return_type = sig.return_type
5041
# 1. Prepare return value pointer
51-
if self.return_type == types.void:
42+
if return_type == types.void:
5243
# Void return type in C++ is shimmed as int& ignored
5344
retval_ty = ir.IntType(32)
5445
retval_ptr = builder.alloca(retval_ty, name="ignored")
5546
else:
56-
retval_ty = context.get_value_type(self.return_type)
47+
retval_ty = context.get_value_type(return_type)
5748
retval_ptr = builder.alloca(retval_ty, name="retval")
5849

5950
# 2. Prepare arguments
@@ -80,9 +71,9 @@ def _lower_impl(self, builder, context, sig, args):
8071
builder.call(fn, (retval_ptr, *ptrs))
8172

8273
# 5. Return
83-
if self.return_type == types.void:
74+
if return_type == types.void:
8475
return None
8576
else:
8677
return builder.load(
87-
retval_ptr, align=getattr(self.return_type, "alignof_", None)
78+
retval_ptr, align=getattr(return_type, "alignof_", None)
8879
)

numbast/src/numbast/class_template.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -110,9 +110,7 @@ def bind_cxx_struct_ctor(
110110
shim_name=shim_func_name, struct_name=struct_name, params=ctor.params
111111
)
112112

113-
ctor_cc = FunctionCallConv(
114-
mangled_name, shim_writer, shim, s_type_ref.instance_type
115-
)
113+
ctor_cc = FunctionCallConv(mangled_name, shim_writer, shim)
116114

117115
@lower(numba_typeref_ctor, s_type_ref, *param_types)
118116
def ctor_impl(context, builder, sig, args):
@@ -197,7 +195,7 @@ def bind_cxx_struct_regular_method(
197195
params=method_decl.params,
198196
)
199197

200-
method_cc = FunctionCallConv(mangled_name, shim_writer, shim, return_type)
198+
method_cc = FunctionCallConv(mangled_name, shim_writer, shim)
201199

202200
qualname = f"{s_type}.{method_decl.name}"
203201

0 commit comments

Comments
 (0)