2727
2828logger = logging .getLogger (f"AST_Canopy.{ __name__ } " )
2929
30+ CXX_FLAG_SETS : dict [str , list [str ]] = {
31+ "cuda-header-parsing-flags" : [
32+ "-nocudainc" ,
33+ "-nocudalib" ,
34+ ]
35+ }
36+
3037
3138def _get_shim_include_dir () -> str :
3239 """Return the absolute path to the local shim include directory"""
@@ -57,6 +64,11 @@ def paths_to_include_flags(paths: list[str]) -> list[str]:
5764 return [f"-I{ path } " for path in paths ]
5865
5966
67+ def _is_cuda_header (source_file_path : str ) -> bool :
68+ header_name = os .path .basename (source_file_path )
69+ return header_name .startswith ("cuda" ) and header_name .endswith (".h" )
70+
71+
6072def get_default_compiler_search_paths (clang_binary : str | None ) -> list [str ]:
6173 """Compile an empty CUDA file with the given clang binary in verbose mode and parse the
6274 output to extract the default system header search paths.
@@ -343,6 +355,7 @@ def parse_declarations_from_source(
343355 defines : list [str ] = [],
344356 verbose : bool = False ,
345357 bypass_parse_error : bool = False ,
358+ cuda_header_mode : bool = False ,
346359) -> Declarations :
347360 """Given a source file, parse all top-level declarations from it and return
348361 a ``Declarations`` object containing lists of declaration objects found in
@@ -384,6 +397,12 @@ def parse_declarations_from_source(
384397 bypass_parse_error : bool, optional
385398 If True, bypass parse error and continue generating bindings.
386399
400+ cuda_header_mode : bool, optional
401+ If True, enable ``cuda-header-parsing-flags`` when parsing CUDA
402+ headers (e.g. ``cuda.h``). This mode disables Clang's implicit CUDA
403+ include injection so declarations are attributed to the provided
404+ header path.
405+
387406 Returns
388407 -------
389408 Declarations
@@ -405,6 +424,22 @@ def parse_declarations_from_source(
405424 if not os .path .exists (source_file_path ):
406425 raise FileNotFoundError (f"File not found: { source_file_path } " )
407426
427+ cuda_header_parsing_flags : list [str ] = []
428+ if cuda_header_mode and _is_cuda_header (source_file_path ):
429+ cuda_header_parsing_flags = CXX_FLAG_SETS ["cuda-header-parsing-flags" ]
430+
431+ # Prefer the source header directory so include resolution keeps the
432+ # declaration locations anchored to the vendored CUDA header.
433+ source_include_dir = os .path .dirname (source_file_path )
434+ if (
435+ source_include_dir
436+ and source_include_dir not in cudatoolkit_include_dirs
437+ ):
438+ cudatoolkit_include_dirs = [
439+ source_include_dir ,
440+ * cudatoolkit_include_dirs ,
441+ ]
442+
408443 for p in additional_includes :
409444 if not isinstance (p , str ):
410445 raise TypeError (f"Additional include path must be a string: { p } " )
@@ -435,8 +470,8 @@ def parse_declarations_from_source(
435470 command_line_options = [
436471 "clang++" ,
437472 * clang_verbose_flag ,
438- "--cuda-device-only" ,
439473 "-xcuda" ,
474+ * cuda_header_parsing_flags ,
440475 f"--cuda-path={ cuda_path } " ,
441476 f"--cuda-gpu-arch={ compute_capability } " ,
442477 f"-std={ cxx_standard } " ,
0 commit comments