Skip to content

Commit d482dbc

Browse files
committed
address review comments
1 parent aa8791e commit d482dbc

File tree

4 files changed

+22
-15
lines changed

4 files changed

+22
-15
lines changed

ast_canopy/ast_canopy/decl.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -300,7 +300,7 @@ def from_c_obj(cls, c_obj: bindings.Method, parse_entry_point: str):
300300

301301

302302
class TemplatedStructMethod(StructMethod):
303-
"""Struct/class method who's name may include template parameters.
303+
"""Struct/class method whose name may include template parameters.
304304
305305
Provides utilities for working with the declaration name without
306306
template arguments.

numbast/src/numbast/function_template.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from ast_canopy.pylibastcanopy import execution_space
1313
from ast_canopy.decl import FunctionTemplate
1414

15-
from numba import types as nbtypes
15+
from numba.cuda import types as nbtypes
1616
from numba.cuda.typing import signature as nb_signature
1717
from numba.cuda.typing.templates import AbstractTemplate
1818
from numba.cuda.cudadecl import register, register_global
@@ -38,6 +38,10 @@ def func():
3838
return func
3939

4040

41+
# Registry key: (name, intent_key, shim_writer). name is str; intent_key is tuple
42+
# or None for intent/overload dispatch; shim_writer is compared by identity so
43+
# different writer instances get separate entries; make_new_func_obj is the
44+
# default factory.
4145
func_obj_registry: dict[tuple[str, tuple | None, object], object] = defaultdict(
4246
make_new_func_obj
4347
)

numbast/src/numbast/overload_selection.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,9 @@ def _select_templated_overload(
2121
"""
2222
Select a FunctionTemplate overload for a templated function/method.
2323
24+
kwds is reserved for future C++-style overload resolution (keyword argument
25+
dispatch) and is intentionally unused today.
26+
2427
Today we only select by explicit argument count (visible arity). Keep this
2528
logic centralized so we can expand it with C++-style overload resolution:
2629
- filter viable candidates (arity/defaults/variadics, arg_intent visibility),

numbast/tests/test_function_template.py

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414

1515

1616
@pytest.fixture
17-
def _sample_function_templates():
17+
def sample_function_templates():
1818
DATA_FOLDER = os.path.join(os.path.dirname(__file__), "data")
1919
p = os.path.join(DATA_FOLDER, "sample_function_template.cuh")
2020
decls = parse_declarations_from_source(p, [p], "sm_80", verbose=True)
@@ -37,8 +37,8 @@ def find_binding(bindings, name):
3737
)
3838

3939

40-
def test_templated_function_overload_selection(_sample_function_templates):
41-
func_bindings, shim_writer = _sample_function_templates
40+
def test_templated_function_overload_selection(sample_function_templates):
41+
func_bindings, shim_writer = sample_function_templates
4242
add = find_binding(func_bindings, "add")
4343

4444
@cuda.jit(link=shim_writer.links())
@@ -55,8 +55,8 @@ def kernel(x, y, z, out):
5555
np.testing.assert_allclose(out, np.array([3.0, 6.0], dtype=np.float32))
5656

5757

58-
def test_templated_function_explicit_specialization(_sample_function_templates):
59-
func_bindings, shim_writer = _sample_function_templates
58+
def test_templated_function_explicit_specialization(sample_function_templates):
59+
func_bindings, shim_writer = sample_function_templates
6060
add = find_binding(func_bindings, "add")
6161

6262
@cuda.jit(link=shim_writer.links())
@@ -76,8 +76,8 @@ def kernel(int_a, int_b, float_a, float_b, out_int, out_float):
7676
np.testing.assert_allclose(out_float, np.array([4.0], dtype=np.float32))
7777

7878

79-
def test_templated_function_default_non_type(_sample_function_templates):
80-
func_bindings, shim_writer = _sample_function_templates
79+
def test_templated_function_default_non_type(sample_function_templates):
80+
func_bindings, shim_writer = sample_function_templates
8181
add_default = find_binding(func_bindings, "add_default")
8282

8383
@cuda.jit(link=shim_writer.links())
@@ -91,8 +91,8 @@ def kernel(inp, out):
9191
assert out[0] == 17
9292

9393

94-
def test_templated_function_default_type(_sample_function_templates):
95-
func_bindings, shim_writer = _sample_function_templates
94+
def test_templated_function_default_type(sample_function_templates):
95+
func_bindings, shim_writer = sample_function_templates
9696
add_default_type = find_binding(func_bindings, "add_default_type")
9797

9898
@cuda.jit(link=shim_writer.links())
@@ -107,9 +107,9 @@ def kernel(inp, out):
107107

108108

109109
def test_templated_function_multiple_template_args(
110-
_sample_function_templates,
110+
sample_function_templates,
111111
):
112-
func_bindings, shim_writer = _sample_function_templates
112+
func_bindings, shim_writer = sample_function_templates
113113
add_cast = find_binding(func_bindings, "add_cast")
114114

115115
@cuda.jit(link=shim_writer.links())
@@ -125,9 +125,9 @@ def kernel(int_a, float_b, out):
125125

126126

127127
def test_templated_function_type_and_non_type(
128-
_sample_function_templates,
128+
sample_function_templates,
129129
):
130-
func_bindings, shim_writer = _sample_function_templates
130+
func_bindings, shim_writer = sample_function_templates
131131
add_with_non_type = find_binding(func_bindings, "add_with_non_type")
132132

133133
@cuda.jit(link=shim_writer.links())

0 commit comments

Comments
 (0)