diff --git a/flashinfer/__main__.py b/flashinfer/__main__.py index 679d8766bc..37b08a3a95 100644 --- a/flashinfer/__main__.py +++ b/flashinfer/__main__.py @@ -26,7 +26,14 @@ ) from .jit import clear_cache_dir, jit_spec_registry from .jit.cubin_loader import FLASHINFER_CUBINS_REPOSITORY -from .jit.env import FLASHINFER_CACHE_DIR, FLASHINFER_CUBIN_DIR +from .jit.env import ( + FLASHINFER_CACHE_DIR, + FLASHINFER_CUBIN_DIR, + get_cutlass_include_dirs, + get_spdlog_include_dir, + get_nvshmem_include_dirs, + get_nvshmem_lib_dirs, +) from .jit.core import current_compilation_context from .jit.cpp_ext import get_cuda_path, get_cuda_version @@ -76,6 +83,10 @@ def cli(ctx, download_cubin_flag): "FLASHINFER_CUDA_ARCH_LIST": current_compilation_context.TARGET_CUDA_ARCHS, "FLASHINFER_CUDA_VERSION": get_cuda_version(), "FLASHINFER_CUBINS_REPOSITORY": FLASHINFER_CUBINS_REPOSITORY, + "FLASHINFER_CUTLASS_INCLUDE_PATH": get_cutlass_include_dirs(), + "FLASHINFER_SPDLOG_INCLUDE_PATH": get_spdlog_include_dir(), + "FLASHINFER_NVSHMEM_INCLUDE_PATH": get_nvshmem_include_dirs(), + "FLASHINFER_NVSHMEM_LIBRARY_PATH": get_nvshmem_lib_dirs(), } try: env_variables["CUDA_HOME"] = get_cuda_path() diff --git a/flashinfer/comm/nvshmem.py b/flashinfer/comm/nvshmem.py index c7e170c279..1b3a5920e5 100644 --- a/flashinfer/comm/nvshmem.py +++ b/flashinfer/comm/nvshmem.py @@ -13,10 +13,14 @@ def gen_nvshmem_module() -> JitSpec: lib_dirs = jit_env.get_nvshmem_lib_dirs() + # Check new environment variable first, then fall back to old one for backward compatibility + ldflags_env = os.environ.get("FLASHINFER_NVSHMEM_LDFLAGS") or os.environ.get( + "NVSHMEM_LDFLAGS", "" + ) ldflags = ( [f"-L{lib_dir}" for lib_dir in lib_dirs] + ["-lnvshmem_device"] - + shlex.split(os.environ.get("NVSHMEM_LDFLAGS", "")) + + shlex.split(ldflags_env) ) return gen_jit_spec( diff --git a/flashinfer/jit/cpp_ext.py b/flashinfer/jit/cpp_ext.py index 693e47c26f..0df1adde32 100644 --- a/flashinfer/jit/cpp_ext.py +++ b/flashinfer/jit/cpp_ext.py @@ -104,8 +104,8 @@ def generate_ninja_build_for_op( jit_env.FLASHINFER_INCLUDE_DIR.resolve(), jit_env.FLASHINFER_CSRC_DIR.resolve(), ] - system_includes += [p.resolve() for p in jit_env.CUTLASS_INCLUDE_DIRS] - system_includes.append(jit_env.SPDLOG_INCLUDE_DIR.resolve()) + system_includes += [p.resolve() for p in jit_env.get_cutlass_include_dirs()] + system_includes.append(jit_env.get_spdlog_include_dir().resolve()) common_cflags = [] if not sysconfig.get_config_var("Py_GIL_DISABLED"): diff --git a/flashinfer/jit/env.py b/flashinfer/jit/env.py index 3e6e6696a5..318b5bddd6 100644 --- a/flashinfer/jit/env.py +++ b/flashinfer/jit/env.py @@ -74,15 +74,39 @@ def _get_workspace_dir_name() -> pathlib.Path: FLASHINFER_CSRC_DIR = _package_root / "data" / "csrc" # FLASHINFER_SRC_DIR = _package_root / "data" / "src" FLASHINFER_AOT_DIR = _package_root / "data" / "aot" -CUTLASS_INCLUDE_DIRS = [ - _package_root / "data" / "cutlass" / "include", - _package_root / "data" / "cutlass" / "tools" / "util" / "include", -] -SPDLOG_INCLUDE_DIR = _package_root / "data" / "spdlog" / "include" + + +def get_cutlass_include_dirs(): + paths = os.environ.get("FLASHINFER_CUTLASS_INCLUDE_PATH") + if paths is not None: + return [pathlib.Path(p) for p in paths.split(os.pathsep) if p] + + # Fall back to default paths + return [ + _package_root / "data" / "cutlass" / "include", + _package_root / "data" / "cutlass" / "tools" / "util" / "include", + ] + + +def get_spdlog_include_dir(): + path = os.environ.get("FLASHINFER_SPDLOG_INCLUDE_PATH") + if path is not None: + return pathlib.Path(path) + + # Fall back to default path + return _package_root / "data" / "spdlog" / "include" + + +# For backward compatibility, keep these as properties +CUTLASS_INCLUDE_DIRS = get_cutlass_include_dirs() +SPDLOG_INCLUDE_DIR = get_spdlog_include_dir() def get_nvshmem_include_dirs(): - paths = os.environ.get("NVSHMEM_INCLUDE_PATH") + # Check new environment variable first, then fall back to old one for backward compatibility + paths = os.environ.get("FLASHINFER_NVSHMEM_INCLUDE_PATH") or os.environ.get( + "NVSHMEM_INCLUDE_PATH" + ) if paths is not None: return [pathlib.Path(p) for p in paths.split(os.pathsep) if p] @@ -93,7 +117,10 @@ def get_nvshmem_include_dirs(): def get_nvshmem_lib_dirs(): - paths = os.environ.get("NVSHMEM_LIBRARY_PATH") + # Check new environment variable first, then fall back to old one for backward compatibility + paths = os.environ.get("FLASHINFER_NVSHMEM_LIBRARY_PATH") or os.environ.get( + "NVSHMEM_LIBRARY_PATH" + ) if paths is not None: return [pathlib.Path(p) for p in paths.split(os.pathsep) if p]