Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 12 additions & 1 deletion flashinfer/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,14 @@
)
from .jit import clear_cache_dir, jit_spec_registry
from .jit.cubin_loader import FLASHINFER_CUBINS_REPOSITORY
from .jit.env import FLASHINFER_CACHE_DIR, FLASHINFER_CUBIN_DIR
from .jit.env import (
FLASHINFER_CACHE_DIR,
FLASHINFER_CUBIN_DIR,
get_cutlass_include_dirs,
get_spdlog_include_dir,
get_nvshmem_include_dirs,
get_nvshmem_lib_dirs,
)
from .jit.core import current_compilation_context
from .jit.cpp_ext import get_cuda_path, get_cuda_version

Expand Down Expand Up @@ -76,6 +83,10 @@ def cli(ctx, download_cubin_flag):
"FLASHINFER_CUDA_ARCH_LIST": current_compilation_context.TARGET_CUDA_ARCHS,
"FLASHINFER_CUDA_VERSION": get_cuda_version(),
"FLASHINFER_CUBINS_REPOSITORY": FLASHINFER_CUBINS_REPOSITORY,
"FLASHINFER_CUTLASS_INCLUDE_PATH": get_cutlass_include_dirs(),
"FLASHINFER_SPDLOG_INCLUDE_PATH": get_spdlog_include_dir(),
"FLASHINFER_NVSHMEM_INCLUDE_PATH": get_nvshmem_include_dirs(),
"FLASHINFER_NVSHMEM_LIBRARY_PATH": get_nvshmem_lib_dirs(),
}
try:
env_variables["CUDA_HOME"] = get_cuda_path()
Expand Down
6 changes: 5 additions & 1 deletion flashinfer/comm/nvshmem.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,14 @@

def gen_nvshmem_module() -> JitSpec:
lib_dirs = jit_env.get_nvshmem_lib_dirs()
# Check new environment variable first, then fall back to old one for backward compatibility
ldflags_env = os.environ.get("FLASHINFER_NVSHMEM_LDFLAGS") or os.environ.get(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

shall we have documentation collecting these FLASHINFER* vars in some .md file, and briefly mention what they may affect and the order in which related vars are consulted. this could help both users and developers understand the expected behavior (which could otherwise run away after a few ppl modifying the same)

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

currently user can run

python -m flashinfer show-config

do display these environment variables, but yes let's clearly document it.

"NVSHMEM_LDFLAGS", ""
)
ldflags = (
[f"-L{lib_dir}" for lib_dir in lib_dirs]
+ ["-lnvshmem_device"]
+ shlex.split(os.environ.get("NVSHMEM_LDFLAGS", ""))
+ shlex.split(ldflags_env)
)

return gen_jit_spec(
Expand Down
4 changes: 2 additions & 2 deletions flashinfer/jit/cpp_ext.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,8 +104,8 @@ def generate_ninja_build_for_op(
jit_env.FLASHINFER_INCLUDE_DIR.resolve(),
jit_env.FLASHINFER_CSRC_DIR.resolve(),
]
system_includes += [p.resolve() for p in jit_env.CUTLASS_INCLUDE_DIRS]
system_includes.append(jit_env.SPDLOG_INCLUDE_DIR.resolve())
system_includes += [p.resolve() for p in jit_env.get_cutlass_include_dirs()]
system_includes.append(jit_env.get_spdlog_include_dir().resolve())

common_cflags = []
if not sysconfig.get_config_var("Py_GIL_DISABLED"):
Expand Down
41 changes: 34 additions & 7 deletions flashinfer/jit/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,15 +74,39 @@ def _get_workspace_dir_name() -> pathlib.Path:
FLASHINFER_CSRC_DIR = _package_root / "data" / "csrc"
# FLASHINFER_SRC_DIR = _package_root / "data" / "src"
FLASHINFER_AOT_DIR = _package_root / "data" / "aot"
CUTLASS_INCLUDE_DIRS = [
_package_root / "data" / "cutlass" / "include",
_package_root / "data" / "cutlass" / "tools" / "util" / "include",
]
SPDLOG_INCLUDE_DIR = _package_root / "data" / "spdlog" / "include"


def get_cutlass_include_dirs():
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

For performance and consistency, it's a good practice to cache the results of this getter function. Since environment variables are not expected to change during the program's execution, caching will prevent redundant lookups and path processing. You can use @functools.cache for this, which is already used elsewhere in the codebase (e.g., in cpp_ext.py).

Please apply this caching to get_spdlog_include_dir, get_nvshmem_include_dirs, and get_nvshmem_lib_dirs as well. You'll need to add import functools at the top of the file.

Suggested change
def get_cutlass_include_dirs():
@functools.cache
def get_cutlass_include_dirs():

paths = os.environ.get("FLASHINFER_CUTLASS_INCLUDE_PATH")
if paths is not None:
return [pathlib.Path(p) for p in paths.split(os.pathsep) if p]

# Fall back to default paths
return [
_package_root / "data" / "cutlass" / "include",
_package_root / "data" / "cutlass" / "tools" / "util" / "include",
]


def get_spdlog_include_dir():
path = os.environ.get("FLASHINFER_SPDLOG_INCLUDE_PATH")
if path is not None:
return pathlib.Path(path)

# Fall back to default path
return _package_root / "data" / "spdlog" / "include"


# For backward compatibility, keep these as properties
CUTLASS_INCLUDE_DIRS = get_cutlass_include_dirs()
SPDLOG_INCLUDE_DIR = get_spdlog_include_dir()


def get_nvshmem_include_dirs():
paths = os.environ.get("NVSHMEM_INCLUDE_PATH")
# Check new environment variable first, then fall back to old one for backward compatibility
paths = os.environ.get("FLASHINFER_NVSHMEM_INCLUDE_PATH") or os.environ.get(
"NVSHMEM_INCLUDE_PATH"
)
if paths is not None:
return [pathlib.Path(p) for p in paths.split(os.pathsep) if p]

Expand All @@ -93,7 +117,10 @@ def get_nvshmem_include_dirs():


def get_nvshmem_lib_dirs():
paths = os.environ.get("NVSHMEM_LIBRARY_PATH")
# Check new environment variable first, then fall back to old one for backward compatibility
paths = os.environ.get("FLASHINFER_NVSHMEM_LIBRARY_PATH") or os.environ.get(
"NVSHMEM_LIBRARY_PATH"
)
if paths is not None:
return [pathlib.Path(p) for p in paths.split(os.pathsep) if p]

Expand Down