Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions ast_canopy/ast_canopy/decl.py
Original file line number Diff line number Diff line change
Expand Up @@ -376,6 +376,10 @@ def __init__(

self.parse_entry_point = parse_entry_point

@property
def name(self):
return self.record.name

@classmethod
def from_c_obj(cls, c_obj: bindings.ClassTemplate, parse_entry_point: str):
return cls(
Expand Down
45 changes: 45 additions & 0 deletions numbast/src/numbast/static/class_template.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
from numbast.static.renderer import BaseRenderer, get_rendered_imports


class StaticClassTemplateRenderer(BaseRenderer):
_python_api_rendered: str

_python_api_template = """
def {class_template_name}():
pass
"""

def __init__(self, decl):
self.decl = decl

def _render_python_api(self):
self._python_api_rendered = self._python_api_template.format(
class_template_name=self.decl.name
)

def render(self):
self._render_python_api()


class StaticClassTemplatesRenderer(BaseRenderer):
def __init__(self, decls):
self.decls = decls

def _render(self, with_imports):
self._python_rendered = []

for decl in self.decls:
SCTR = StaticClassTemplateRenderer(decl)
SCTR.render()
self._python_rendered.append(SCTR._python_api_rendered)

self._python_str = ""

if with_imports:
self._python_str += "\n" + get_rendered_imports()

self._python_str += "\n" + "\n".join(self._python_rendered)

def render_as_str(self, with_imports: bool, with_shim_stream: bool) -> str:
self._render(with_imports)
return self._python_str
19 changes: 19 additions & 0 deletions numbast/src/numbast/static/tests/data/class_template.cuh
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
template <typename T, int BLOCK_DIM_X, int ITEMS_PER_THREAD> class BlockLoad {
public:
struct TempStorage {};

__device__ BlockLoad() {}
__device__ explicit BlockLoad(TempStorage &) {}

// Single-item load: just assign input to output
__device__ void Load(T input, T &output) { output = input; }

// Array load: copy ITEMS_PER_THREAD elements from input to output
__device__ void Load(T(input)[ITEMS_PER_THREAD],
T (&output)[ITEMS_PER_THREAD]) {
#pragma unroll
for (int i = 0; i < ITEMS_PER_THREAD; ++i) {
output[i] = input[i];
}
}
};
46 changes: 46 additions & 0 deletions numbast/src/numbast/static/tests/test_class_templates.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0


import pytest


from ast_canopy import parse_declarations_from_source

from numbast.static.renderer import clear_base_renderer_cache, registry_setup
from numbast.static.class_template import (
StaticClassTemplatesRenderer,
)


@pytest.fixture(autouse=True, scope="module")
def cleanup():
clear_base_renderer_cache()


@pytest.fixture(scope="module")
def decl(data_folder, cleanup):
header = data_folder("class_template.cuh")

decls = parse_declarations_from_source(header, [header], "sm_50")
class_templates = decls.class_templates

assert len(class_templates) == 1

registry_setup(use_separate_registry=False)
SFR = StaticClassTemplatesRenderer(class_templates)

bindings = SFR.render_as_str(with_imports=True, with_shim_stream=True)
globals = {}
exec(bindings, globals)

print(bindings)
public_apis = ["BlockLoad"]
assert all(public_api in globals for public_api in public_apis)

return {k: globals[k] for k in public_apis}


def test_class_template(decl):
print(decl)
assert decl["BlockLoad"] is not None
18 changes: 17 additions & 1 deletion numbast/src/numbast/tools/static_binding_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@

from ast_canopy import parse_declarations_from_source
from ast_canopy.decl import Function, Struct
from pylibastcanopy import Enum, Typedef
from pylibastcanopy import Enum, Typedef, ClassTemplate

from numbast.static import reset_renderer
from numbast.static.renderer import (
Expand All @@ -36,6 +36,7 @@
)
from numbast.static.enum import StaticEnumsRenderer
from numbast.static.typedef import render_aliases
from numbast.static.class_template import StaticClassTemplatesRenderer
from numbast.tools.yaml_tags import string_constructor

config.CUDA_USE_NVIDIA_BINDING = True
Expand Down Expand Up @@ -415,6 +416,15 @@ def _generate_functions(
return SFR.render_as_str(with_imports=False, with_shim_stream=False)


def _generate_class_templates(
class_template_decls: list[ClassTemplate],
header_path: str,
):
"""Create class template bindings."""
SCTR = StaticClassTemplatesRenderer(class_template_decls)
return SCTR.render_as_str(with_imports=False, with_shim_stream=False)


def _generate_enums(enum_decls: list[Enum]):
"""Create enum bindings."""
SER = StaticEnumsRenderer(enum_decls)
Expand Down Expand Up @@ -509,6 +519,7 @@ def _static_binding_generator(
for td in decls.typedefs
if td.underlying_name not in config.exclude_structs
]
class_templates = decls.class_templates

if log_generates:
log_files_to_generate(functions, structs, enums, typedefs)
Expand All @@ -534,6 +545,11 @@ def _static_binding_generator(
config.skip_prefix,
)

_ = _generate_class_templates(
class_templates,
entry_point,
)

registry_setup_str = registry_setup(config.separate_registry)

if config.shim_include_override is not None:
Expand Down