Skip to content

Commit dfdc002

Browse files
committed
add host side function tests
1 parent aae2286 commit dfdc002

File tree

3 files changed

+95
-0
lines changed

3 files changed

+95
-0
lines changed

ast_canopy/pyproject.toml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,10 @@ name = "ast_canopy"
1010
dynamic = ["version"]
1111
readme = { file = "README.md", content-type = "text/markdown" }
1212

13+
[project.optional-dependencies]
14+
cu13 = ["cuda-toolkit[cudart, crt, curand, cccl]"]
15+
16+
1317
[tool.scikit-build]
1418
cmake.targets = ["pylibastcanopy"]
1519
wheel.license-files = ["../LICENSE"]
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
// clang-format off
2+
// SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3+
// SPDX-License-Identifier: Apache-2.0
4+
// clang-format on
5+
6+
int add(int a, int b) { return a + b; }
7+
8+
float scale(float value, float factor) { return value * factor; }
9+
10+
void set_value(int *out, int value) { *out = value; }
11+
12+
double __host__ host_offset(double x, double offset) { return x + offset; }
13+
14+
int __host__ __device__ add_host_device(int a, int b) { return a + b; }
Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
4+
import os
5+
6+
import pytest
7+
8+
from ast_canopy import parse_declarations_from_source
9+
from ast_canopy.pylibastcanopy import execution_space
10+
11+
12+
@pytest.fixture(scope="module")
13+
def host_function_source():
14+
current_directory = os.path.dirname(os.path.abspath(__file__))
15+
return os.path.join(current_directory, "data", "sample_host_function.cpp")
16+
17+
18+
def test_parse_host_functions(host_function_source):
19+
decls = parse_declarations_from_source(
20+
host_function_source, [host_function_source], "sm_80"
21+
)
22+
23+
functions = decls.functions
24+
expected_names = {
25+
"add",
26+
"scale",
27+
"set_value",
28+
"host_offset",
29+
"add_host_device",
30+
}
31+
32+
assert len(functions) == len(expected_names)
33+
assert {func.name for func in functions} == expected_names
34+
35+
funcs_by_name = {func.name: func for func in functions}
36+
37+
add = funcs_by_name["add"]
38+
assert add.return_type.name == "int"
39+
assert [param.name for param in add.params] == ["a", "b"]
40+
assert [param.type_.name for param in add.params] == ["int", "int"]
41+
assert add.exec_space == execution_space.undefined
42+
43+
scale = funcs_by_name["scale"]
44+
assert scale.return_type.name == "float"
45+
assert [param.name for param in scale.params] == ["value", "factor"]
46+
assert [param.type_.name for param in scale.params] == ["float", "float"]
47+
assert scale.exec_space == execution_space.undefined
48+
49+
set_value = funcs_by_name["set_value"]
50+
assert set_value.return_type.name == "void"
51+
assert [param.name for param in set_value.params] == ["out", "value"]
52+
assert [param.type_.name for param in set_value.params] == [
53+
"int *",
54+
"int",
55+
]
56+
assert set_value.exec_space == execution_space.undefined
57+
58+
host_offset = funcs_by_name["host_offset"]
59+
assert host_offset.return_type.name == "double"
60+
assert [param.name for param in host_offset.params] == ["x", "offset"]
61+
assert [param.type_.name for param in host_offset.params] == [
62+
"double",
63+
"double",
64+
]
65+
assert host_offset.exec_space == execution_space.host
66+
67+
add_host_device = funcs_by_name["add_host_device"]
68+
assert add_host_device.return_type.name == "int"
69+
assert [param.name for param in add_host_device.params] == ["a", "b"]
70+
assert [param.type_.name for param in add_host_device.params] == [
71+
"int",
72+
"int",
73+
]
74+
assert add_host_device.exec_space == execution_space.host_device
75+
76+
for func in functions:
77+
assert func.parse_entry_point == host_function_source

0 commit comments

Comments
 (0)