Skip to content

Commit 4a71423

Browse files
committed
Add CUDA header mode and vendored cuda.h host test
1 parent 36ebaf6 commit 4a71423

File tree

4 files changed

+29724
-1
lines changed

4 files changed

+29724
-1
lines changed

.pre-commit-config.yaml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,10 @@ repos:
3131
rev: 'v20.1.0'
3232
hooks:
3333
- id: clang-format
34+
exclude: |
35+
(?x)^(
36+
ast_canopy/tests/host/cuda_headers/include/cuda.h
37+
)$
3438
- repo: https://github.com/codespell-project/codespell
3539
rev: v2.4.1
3640
hooks:

ast_canopy/ast_canopy/api.py

Lines changed: 36 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,13 @@
2727

2828
logger = logging.getLogger(f"AST_Canopy.{__name__}")
2929

30+
CXX_FLAG_SETS: dict[str, list[str]] = {
31+
"cuda-header-parsing-flags": [
32+
"-nocudainc",
33+
"-nocudalib",
34+
]
35+
}
36+
3037

3138
def _get_shim_include_dir() -> str:
3239
"""Return the absolute path to the local shim include directory"""
@@ -57,6 +64,11 @@ def paths_to_include_flags(paths: list[str]) -> list[str]:
5764
return [f"-I{path}" for path in paths]
5865

5966

67+
def _is_cuda_header(source_file_path: str) -> bool:
68+
header_name = os.path.basename(source_file_path)
69+
return header_name.startswith("cuda") and header_name.endswith(".h")
70+
71+
6072
def get_default_compiler_search_paths(clang_binary: str | None) -> list[str]:
6173
"""Compile an empty CUDA file with the given clang binary in verbose mode and parse the
6274
output to extract the default system header search paths.
@@ -343,6 +355,7 @@ def parse_declarations_from_source(
343355
defines: list[str] = [],
344356
verbose: bool = False,
345357
bypass_parse_error: bool = False,
358+
cuda_header_mode: bool = False,
346359
) -> Declarations:
347360
"""Given a source file, parse all top-level declarations from it and return
348361
a ``Declarations`` object containing lists of declaration objects found in
@@ -384,6 +397,12 @@ def parse_declarations_from_source(
384397
bypass_parse_error : bool, optional
385398
If True, bypass parse error and continue generating bindings.
386399
400+
cuda_header_mode : bool, optional
401+
If True, enable ``cuda-header-parsing-flags`` when parsing CUDA
402+
headers (e.g. ``cuda.h``). This mode disables Clang's implicit CUDA
403+
include injection so declarations are attributed to the provided
404+
header path.
405+
387406
Returns
388407
-------
389408
Declarations
@@ -405,6 +424,22 @@ def parse_declarations_from_source(
405424
if not os.path.exists(source_file_path):
406425
raise FileNotFoundError(f"File not found: {source_file_path}")
407426

427+
cuda_header_parsing_flags: list[str] = []
428+
if cuda_header_mode and _is_cuda_header(source_file_path):
429+
cuda_header_parsing_flags = CXX_FLAG_SETS["cuda-header-parsing-flags"]
430+
431+
# Prefer the source header directory so include resolution keeps the
432+
# declaration locations anchored to the vendored CUDA header.
433+
source_include_dir = os.path.dirname(source_file_path)
434+
if (
435+
source_include_dir
436+
and source_include_dir not in cudatoolkit_include_dirs
437+
):
438+
cudatoolkit_include_dirs = [
439+
source_include_dir,
440+
*cudatoolkit_include_dirs,
441+
]
442+
408443
for p in additional_includes:
409444
if not isinstance(p, str):
410445
raise TypeError(f"Additional include path must be a string: {p}")
@@ -435,8 +470,8 @@ def parse_declarations_from_source(
435470
command_line_options = [
436471
"clang++",
437472
*clang_verbose_flag,
438-
"--cuda-device-only",
439473
"-xcuda",
474+
*cuda_header_parsing_flags,
440475
f"--cuda-path={cuda_path}",
441476
f"--cuda-gpu-arch={compute_capability}",
442477
f"-std={cxx_standard}",

0 commit comments

Comments
 (0)