From d0f884d2cdbc4961ecc5dfbdb77bc96cdb595a47 Mon Sep 17 00:00:00 2001 From: Zihao Ye Date: Sun, 5 Oct 2025 19:14:22 -0400 Subject: [PATCH 01/41] upd --- .github/workflows/nightly-release.yml | 344 ++++++++++++++++++ flashinfer-cubin/build_backend.py | 28 ++ flashinfer-cubin/flashinfer_cubin/__init__.py | 13 + .../flashinfer_jit_cache/__init__.py | 2 + flashinfer-jit-cache/setup.py | 26 ++ flashinfer/__init__.py | 2 + scripts/build_flashinfer_jit_cache_whl.sh | 12 + scripts/update_whl_index.py | 208 ++++++++++- setup.py | 25 +- 9 files changed, 646 insertions(+), 14 deletions(-) create mode 100644 .github/workflows/nightly-release.yml diff --git a/.github/workflows/nightly-release.yml b/.github/workflows/nightly-release.yml new file mode 100644 index 0000000000..9cfdaff778 --- /dev/null +++ b/.github/workflows/nightly-release.yml @@ -0,0 +1,344 @@ +name: Nightly Release + +on: + schedule: + # Run at 00:00 UTC every day + - cron: '0 0 * * *' + workflow_dispatch: + inputs: + date_suffix: + description: 'Date suffix for dev version (YYYYMMDD, leave empty for today)' + required: false + type: string + pull_request: + # TODO: Remove this before merging - only for debugging this PR + +jobs: + setup: + runs-on: ubuntu-latest + outputs: + dev_suffix: ${{ steps.set-suffix.outputs.dev_suffix }} + release_tag: ${{ steps.set-suffix.outputs.release_tag }} + version: ${{ steps.set-suffix.outputs.version }} + steps: + - name: Checkout code + uses: actions/checkout@v4 + + - name: Set date suffix and release tag + id: set-suffix + run: | + # Read version from version.txt + VERSION=$(cat version.txt | tr -d '[:space:]') + + # Set date suffix + if [ -n "${{ inputs.date_suffix }}" ]; then + DEV_SUFFIX="${{ inputs.date_suffix }}" + else + DEV_SUFFIX=$(date -u +%Y%m%d) + fi + + # Create release tag with version + RELEASE_TAG="nightly-v${VERSION}-${DEV_SUFFIX}" + + echo "version=${VERSION}" >> $GITHUB_OUTPUT + echo "dev_suffix=${DEV_SUFFIX}" >> $GITHUB_OUTPUT + echo "release_tag=${RELEASE_TAG}" >> $GITHUB_OUTPUT + echo "Base version: ${VERSION}" + echo "Using dev suffix: ${DEV_SUFFIX}" + echo "Release tag: ${RELEASE_TAG}" + + build-flashinfer-python: + needs: setup + runs-on: ubuntu-latest + steps: + - name: Checkout code + uses: actions/checkout@v4 + with: + submodules: true + + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: '3.10' + + - name: Install build dependencies + run: | + python -m pip install --upgrade pip + pip install build wheel + + - name: Build flashinfer-python sdist + env: + FLASHINFER_DEV_RELEASE_SUFFIX: ${{ needs.setup.outputs.dev_suffix }} + run: | + echo "Building flashinfer-python with dev suffix: ${FLASHINFER_DEV_RELEASE_SUFFIX}" + echo "Git commit: $(git rev-parse HEAD)" + python -m build --sdist + ls -lh dist/ + + - name: Verify version and git version + run: | + tar -xzf dist/*.tar.gz -C /tmp + EXTRACTED_DIR=$(find /tmp -maxdepth 1 -name "flashinfer-python-*" -type d) + cd "$EXTRACTED_DIR" + python -c " + import sys + sys.path.insert(0, '.') + from flashinfer._build_meta import __version__, __git_version__ + print(f'đŸ“Ļ Package version: {__version__}') + print(f'🔖 Git version: {__git_version__}') + " + + - name: Upload flashinfer-python artifact + uses: actions/upload-artifact@v4 + with: + name: flashinfer-python-sdist + path: dist/*.tar.gz + retention-days: 7 + + build-flashinfer-cubin: + needs: setup + runs-on: ubuntu-latest + steps: + - name: Checkout code + uses: actions/checkout@v4 + with: + submodules: true + + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: '3.10' + + - name: Install build dependencies + run: | + python -m pip install --upgrade pip + pip install build twine wheel + pip install setuptools>=61.0 requests filelock torch tqdm numpy apache-tvm-ffi==0.1.0b15 + + - name: Build flashinfer-cubin wheel + env: + FLASHINFER_DEV_RELEASE_SUFFIX: ${{ needs.setup.outputs.dev_suffix }} + run: | + echo "Building flashinfer-cubin with dev suffix: ${FLASHINFER_DEV_RELEASE_SUFFIX}" + echo "Git commit: $(git rev-parse HEAD)" + cd flashinfer-cubin + rm -rf dist build *.egg-info + python -m build --wheel + ls -lh dist/ + mkdir -p ../dist + cp dist/*.whl ../dist/ + + - name: Verify version and git version + run: | + python -m pip install dist/*.whl + python -c " + import flashinfer_cubin + print(f'đŸ“Ļ Package version: {flashinfer_cubin.__version__}') + print(f'🔖 Git version: {flashinfer_cubin.__git_version__}') + " + + - name: Upload flashinfer-cubin artifact + uses: actions/upload-artifact@v4 + with: + name: flashinfer-cubin-wheel + path: dist/*.whl + retention-days: 7 + + build-flashinfer-jit-cache: + needs: setup + strategy: + fail-fast: false + matrix: + cuda: ["12.8", "12.9", "13.0"] + arch: ['x86_64', 'aarch64'] + + runs-on: [self-hosted, "${{ matrix.arch == 'aarch64' && 'arm64' || matrix.arch }}"] + + steps: + - name: Display Machine Information + run: | + echo "CPU: $(nproc) cores, $(lscpu | grep 'Model name' | cut -d':' -f2 | xargs)" + echo "RAM: $(free -h | awk '/^Mem:/ {print $7 " available out of " $2}')" + echo "Disk: $(df -h / | awk 'NR==2 {print $4 " available out of " $2}')" + echo "Architecture: $(uname -m)" + + - name: Checkout code + uses: actions/checkout@v4 + with: + submodules: true + + - name: Build wheel in container + env: + DOCKER_IMAGE: ${{ matrix.arch == 'aarch64' && format('pytorch/manylinuxaarch64-builder:cuda{0}', matrix.cuda) || format('pytorch/manylinux2_28-builder:cuda{0}', matrix.cuda) }} + FLASHINFER_CUDA_ARCH_LIST: ${{ matrix.cuda == '12.8' && '7.5 8.0 8.9 9.0a 10.0a 12.0a' || '7.5 8.0 8.9 9.0a 10.0a 10.3a 12.0a' }} + FLASHINFER_DEV_RELEASE_SUFFIX: ${{ needs.setup.outputs.dev_suffix }} + run: | + # Extract CUDA major and minor versions + CUDA_MAJOR=$(echo "${{ matrix.cuda }}" | cut -d'.' -f1) + CUDA_MINOR=$(echo "${{ matrix.cuda }}" | cut -d'.' -f2) + export CUDA_MAJOR + export CUDA_MINOR + export CUDA_VERSION_SUFFIX="cu${CUDA_MAJOR}${CUDA_MINOR}" + + chown -R $(id -u):$(id -g) ${{ github.workspace }} + mkdir -p ${{ github.workspace }}/ci-cache + chown -R $(id -u):$(id -g) ${{ github.workspace }}/ci-cache + + # Run the build script inside the container with proper mounts + docker run --rm \ + -v ${{ github.workspace }}:/workspace \ + -v ${{ github.workspace }}/ci-cache:/ci-cache \ + -e FLASHINFER_CI_CACHE=/ci-cache \ + -e CUDA_VERSION="${{ matrix.cuda }}" \ + -e CUDA_MAJOR="$CUDA_MAJOR" \ + -e CUDA_MINOR="$CUDA_MINOR" \ + -e CUDA_VERSION_SUFFIX="$CUDA_VERSION_SUFFIX" \ + -e FLASHINFER_DEV_RELEASE_SUFFIX="${FLASHINFER_DEV_RELEASE_SUFFIX}" \ + -e ARCH="${{ matrix.arch }}" \ + -e FLASHINFER_CUDA_ARCH_LIST="${FLASHINFER_CUDA_ARCH_LIST}" \ + --user $(id -u):$(id -g) \ + -w /workspace \ + ${{ env.DOCKER_IMAGE }} \ + bash /workspace/scripts/build_flashinfer_jit_cache_whl.sh + timeout-minutes: 180 + + - name: Display wheel size + run: du -h flashinfer-jit-cache/dist/* + + - name: Create artifact name + id: artifact-name + run: | + CUDA_NO_DOT=$(echo "${{ matrix.cuda }}" | tr -d '.') + echo "name=jit-cache-cu${CUDA_NO_DOT}-${{ matrix.arch }}" >> $GITHUB_OUTPUT + + - name: Upload flashinfer-jit-cache artifact + uses: actions/upload-artifact@v4 + with: + name: ${{ steps.artifact-name.outputs.name }} + path: flashinfer-jit-cache/dist/*.whl + retention-days: 7 + + create-release: + needs: [setup, build-flashinfer-python, build-flashinfer-cubin, build-flashinfer-jit-cache] + runs-on: ubuntu-latest + permissions: + contents: write + steps: + - name: Checkout code + uses: actions/checkout@v4 + + - name: Create GitHub Release (empty first) + env: + GH_TOKEN: ${{ github.token }} + run: | + TAG="${{ needs.setup.outputs.release_tag }}" + + # Delete existing release and tag if they exist + if gh release view "$TAG" &>/dev/null; then + echo "Deleting existing release: $TAG" + gh release delete "$TAG" --yes --cleanup-tag + fi + + # Create new release without assets first + gh release create "$TAG" \ + --title "Nightly Release v${{ needs.setup.outputs.version }}-${{ needs.setup.outputs.dev_suffix }}" \ + --notes "Automated nightly build for version ${{ needs.setup.outputs.version }} (dev${{ needs.setup.outputs.dev_suffix }})" \ + --prerelease + + - name: Download flashinfer-python artifact + uses: actions/download-artifact@v4 + with: + name: flashinfer-python-sdist + path: dist-python/ + + - name: Upload flashinfer-python to release + env: + GH_TOKEN: ${{ github.token }} + run: | + gh release upload "${{ needs.setup.outputs.release_tag }}" dist-python/* --clobber + + - name: Download flashinfer-cubin artifact + uses: actions/download-artifact@v4 + with: + name: flashinfer-cubin-wheel + path: dist-cubin/ + + - name: Upload flashinfer-cubin to release + env: + GH_TOKEN: ${{ github.token }} + run: | + gh release upload "${{ needs.setup.outputs.release_tag }}" dist-cubin/* --clobber + + - name: Upload flashinfer-jit-cache wheels to release (one at a time to avoid OOM) + env: + GH_TOKEN: ${{ github.token }} + run: | + # Upload jit-cache wheels one at a time to avoid OOM + # Each wheel can be several GB, so we download, upload, delete, repeat + mkdir -p dist-jit-cache + + for cuda in 128 129 130; do + for arch in x86_64 aarch64; do + ARTIFACT_NAME="jit-cache-cu${cuda}-${arch}" + echo "Processing ${ARTIFACT_NAME}..." + + # Download this specific artifact + gh run download ${{ github.run_id }} -n "${ARTIFACT_NAME}" -D dist-jit-cache/ || { + echo "Warning: Failed to download ${ARTIFACT_NAME}, skipping..." + continue + } + + # Upload to release + if [ -n "$(ls -A dist-jit-cache/)" ]; then + gh release upload "${{ needs.setup.outputs.release_tag }}" dist-jit-cache/* --clobber + echo "✅ Uploaded ${ARTIFACT_NAME}" + fi + + # Clean up to save disk space before next iteration + rm -rf dist-jit-cache/* + done + done + + update-wheel-index: + needs: [setup, create-release] + runs-on: ubuntu-latest + steps: + - name: Checkout flashinfer repo + uses: actions/checkout@v4 + + - name: Download all artifacts + uses: actions/download-artifact@v4 + with: + path: artifacts/ + + - name: Collect wheels + run: | + mkdir -p dist + find artifacts/ -name "*.whl" -exec cp {} dist/ \; + ls -lh dist/ + + - name: Clone wheel index + run: git clone https://oauth2:${WHL_TOKEN}@github.com/flashinfer-ai/whl.git flashinfer-whl + env: + WHL_TOKEN: ${{ secrets.WHL_TOKEN }} + + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: '3.10' + + - name: Update wheel index + run: | + python3 scripts/update_whl_index.py \ + --dist-dir dist \ + --output-dir flashinfer-whl \ + --release-tag "${{ needs.setup.outputs.release_tag }}" + + - name: Push wheel index + run: | + cd flashinfer-whl + git config --local user.name "github-actions[bot]" + git config --local user.email "41898282+github-actions[bot]@users.noreply.github.com" + git add -A + git commit -m "update whl for nightly ${{ needs.setup.outputs.dev_suffix }}" + git push diff --git a/flashinfer-cubin/build_backend.py b/flashinfer-cubin/build_backend.py index 8815390ad5..18dc48f53f 100644 --- a/flashinfer-cubin/build_backend.py +++ b/flashinfer-cubin/build_backend.py @@ -51,6 +51,25 @@ def _download_cubins(): os.environ.pop("FLASHINFER_CUBIN_DIR", None) +def _get_git_version(): + """Get git commit hash.""" + import subprocess + + try: + git_version = ( + subprocess.check_output( + ["git", "rev-parse", "HEAD"], + cwd=Path(__file__).parent.parent, + stderr=subprocess.DEVNULL, + ) + .decode("ascii") + .strip() + ) + return git_version + except Exception: + return "unknown" + + def _create_build_metadata(): """Create build metadata file with version information.""" version_file = Path(__file__).parent.parent / "version.txt" @@ -60,6 +79,14 @@ def _create_build_metadata(): else: version = "0.0.0+unknown" + # Add dev suffix if specified + dev_suffix = os.environ.get("FLASHINFER_DEV_RELEASE_SUFFIX", "") + if dev_suffix: + version = f"{version}.dev{dev_suffix}" + + # Get git version + git_version = _get_git_version() + # Create build metadata in the source tree package_dir = Path(__file__).parent / "flashinfer_cubin" build_meta_file = package_dir / "_build_meta.py" @@ -67,6 +94,7 @@ def _create_build_metadata(): with open(build_meta_file, "w") as f: f.write('"""Build metadata for flashinfer-cubin package."""\n') f.write(f'__version__ = "{version}"\n') + f.write(f'__git_version__ = "{git_version}"\n') print(f"Created build metadata file with version {version}") return version diff --git a/flashinfer-cubin/flashinfer_cubin/__init__.py b/flashinfer-cubin/flashinfer_cubin/__init__.py index b28fb09da5..abfe816282 100644 --- a/flashinfer-cubin/flashinfer_cubin/__init__.py +++ b/flashinfer-cubin/flashinfer_cubin/__init__.py @@ -63,5 +63,18 @@ def _get_version(): return "0.0.0" +def _get_git_version(): + # First try to read from build metadata (for wheel distributions) + try: + from . import _build_meta + + return _build_meta.__git_version__ + except (ImportError, AttributeError): + pass + + return "unknown" + + __version__ = _get_version() +__git_version__ = _get_git_version() __all__ = ["get_cubin_dir", "list_cubins", "get_cubin_path", "CUBIN_DIR"] diff --git a/flashinfer-jit-cache/flashinfer_jit_cache/__init__.py b/flashinfer-jit-cache/flashinfer_jit_cache/__init__.py index 986232c63a..3be467e99a 100644 --- a/flashinfer-jit-cache/flashinfer_jit_cache/__init__.py +++ b/flashinfer-jit-cache/flashinfer_jit_cache/__init__.py @@ -29,8 +29,10 @@ def get_jit_cache_dir() -> str: try: from ._build_meta import __version__ as __version__ + from ._build_meta import __git_version__ as __git_version__ except ModuleNotFoundError: __version__ = "0.0.0+unknown" + __git_version__ = "unknown" __all__ = [ diff --git a/flashinfer-jit-cache/setup.py b/flashinfer-jit-cache/setup.py index e65b216cce..2c0a0c7dbe 100644 --- a/flashinfer-jit-cache/setup.py +++ b/flashinfer-jit-cache/setup.py @@ -16,6 +16,11 @@ def get_version(): else: version = "0.0.0" + # Add dev suffix if specified + dev_suffix = os.environ.get("FLASHINFER_DEV_RELEASE_SUFFIX", "") + if dev_suffix: + version = f"{version}.dev{dev_suffix}" + # Append CUDA version suffix if available cuda_suffix = os.environ.get("CUDA_VERSION_SUFFIX", "") if cuda_suffix: @@ -25,13 +30,34 @@ def get_version(): return version +def get_git_version(): + """Get git commit hash.""" + import subprocess + + try: + git_version = ( + subprocess.check_output( + ["git", "rev-parse", "HEAD"], + cwd=Path(__file__).parent.parent, + stderr=subprocess.DEVNULL, + ) + .decode("ascii") + .strip() + ) + return git_version + except Exception: + return "unknown" + + def generate_build_meta(): """Generate build metadata file.""" build_meta_file = Path(__file__).parent / "flashinfer_jit_cache" / "_build_meta.py" version = get_version() + git_version = get_git_version() with open(build_meta_file, "w") as f: f.write('"""Build metadata for flashinfer-jit-cache package."""\n') f.write(f'__version__ = "{version}"\n') + f.write(f'__git_version__ = "{git_version}"\n') class PlatformSpecificBdistWheel(bdist_wheel): diff --git a/flashinfer/__init__.py b/flashinfer/__init__.py index ccfae46ee7..ec08322649 100644 --- a/flashinfer/__init__.py +++ b/flashinfer/__init__.py @@ -18,8 +18,10 @@ try: from ._build_meta import __version__ as __version__ + from ._build_meta import __git_version__ as __git_version__ except ModuleNotFoundError: __version__ = "0.0.0+unknown" + __git_version__ = "unknown" from . import jit as jit diff --git a/scripts/build_flashinfer_jit_cache_whl.sh b/scripts/build_flashinfer_jit_cache_whl.sh index 30495ff377..fd1a313e14 100755 --- a/scripts/build_flashinfer_jit_cache_whl.sh +++ b/scripts/build_flashinfer_jit_cache_whl.sh @@ -27,8 +27,10 @@ echo "CUDA Major: ${CUDA_MAJOR}" echo "CUDA Minor: ${CUDA_MINOR}" echo "CUDA Version Suffix: ${CUDA_VERSION_SUFFIX}" echo "CUDA Architectures: ${FLASHINFER_CUDA_ARCH_LIST}" +echo "Dev Release Suffix: ${FLASHINFER_DEV_RELEASE_SUFFIX}" echo "MAX_JOBS: ${MAX_JOBS}" echo "Python Version: $(python3 --version)" +echo "Git commit: $(git rev-parse HEAD 2>/dev/null || echo 'unknown')" echo "Working directory: $(pwd)" echo "" @@ -62,6 +64,16 @@ echo "" echo "Built wheels:" ls -lh dist/ +# Verify version and git version +echo "" +echo "Verifying version and git version..." +pip install dist/*.whl +python -c " +import flashinfer_jit_cache +print(f'đŸ“Ļ Package version: {flashinfer_jit_cache.__version__}') +print(f'🔖 Git version: {flashinfer_jit_cache.__git_version__}') +" + # Copy wheels to output directory if specified if [ -n "${OUTPUT_DIR}" ]; then echo "" diff --git a/scripts/update_whl_index.py b/scripts/update_whl_index.py index 4902d0cc7f..01b6187996 100644 --- a/scripts/update_whl_index.py +++ b/scripts/update_whl_index.py @@ -1,17 +1,199 @@ +""" +Update wheel index for flashinfer packages. + +This script generates PEP 503 compatible simple repository index pages for: +- flashinfer-python (no CUDA suffix in version) +- flashinfer-cubin (no CUDA suffix in version) +- flashinfer-jit-cache (has CUDA suffix like +cu130) + +The index is organized by CUDA version for jit-cache, and flat for others. +""" + import hashlib import pathlib import re +import argparse +import sys +from typing import Optional + + +def get_cuda_version(wheel_name: str) -> Optional[str]: + """Extract CUDA version from wheel filename.""" + # Match patterns like +cu128, +cu129, +cu130 + match = re.search(r"\+cu(\d+)", wheel_name) + if match: + return match.group(1) + return None + + +def get_package_info(wheel_path: pathlib.Path) -> Optional[dict]: + """Extract package information from wheel filename.""" + wheel_name = wheel_path.name + + # Try flashinfer-python pattern + match = re.match(r"flashinfer_python-([0-9.]+(?:\.dev\d+)?)-", wheel_name) + if match: + version = match.group(1) + return { + "package": "flashinfer-python", + "version": version, + "cuda": None, + } + + # Try flashinfer-cubin pattern + match = re.match(r"flashinfer_cubin-([0-9.]+(?:\.dev\d+)?)-", wheel_name) + if match: + version = match.group(1) + return { + "package": "flashinfer-cubin", + "version": version, + "cuda": None, + } + + # Try flashinfer-jit-cache pattern (has CUDA suffix in version) + match = re.match(r"flashinfer_jit_cache-([0-9.]+(?:\.dev\d+)?\+cu\d+)-", wheel_name) + if match: + version = match.group(1) + cuda_ver = get_cuda_version(wheel_name) + return { + "package": "flashinfer-jit-cache", + "version": version, + "cuda": cuda_ver, + } + + return None + + +def compute_sha256(file_path: pathlib.Path) -> str: + """Compute SHA256 hash of a file.""" + with open(file_path, "rb") as f: + return hashlib.sha256(f.read()).hexdigest() + + +def update_index( + dist_dir: str = "dist", + output_dir: str = "whl", + base_url: str = "https://github.com/flashinfer-ai/flashinfer/releases/download", + release_tag: Optional[str] = None, +): + """ + Update wheel index from dist directory. + + Args: + dist_dir: Directory containing wheel files + output_dir: Output directory for index files + base_url: Base URL for wheel downloads + release_tag: GitHub release tag (e.g., 'nightly' or 'v0.3.1') + """ + dist_path = pathlib.Path(dist_dir) + if not dist_path.exists(): + print(f"Error: dist directory '{dist_dir}' does not exist") + sys.exit(1) + + wheels = sorted(dist_path.glob("*.whl")) + if not wheels: + print(f"No wheel files found in '{dist_dir}'") + sys.exit(1) + + print(f"Found {len(wheels)} wheel file(s)") + + for wheel_path in wheels: + print(f"\nProcessing: {wheel_path.name}") + + # Extract package information + info = get_package_info(wheel_path) + if not info: + print(" âš ī¸ Skipping: Could not parse wheel filename") + continue + + # Compute SHA256 + sha256 = compute_sha256(wheel_path) + + # Determine index directory + package = info["package"] + cuda = info["cuda"] + + if cuda: + # CUDA-specific index for jit-cache: whl/cu130/flashinfer-jit-cache/ + index_dir = pathlib.Path(output_dir) / f"cu{cuda}" / package + else: + # No CUDA version for python/cubin: whl/flashinfer-python/ + index_dir = pathlib.Path(output_dir) / package + + index_dir.mkdir(parents=True, exist_ok=True) + + # Construct download URL + tag = release_tag or f"v{info['version'].split('+')[0].split('.dev')[0]}" + download_url = f"{base_url}/{tag}/{wheel_path.name}#sha256={sha256}" + + # Update index.html + index_file = index_dir / "index.html" + + # Read existing content to avoid duplicates + existing_links = set() + if index_file.exists(): + with index_file.open("r") as f: + existing_links = set(f.readlines()) + else: + # Create new index file with HTML header + with index_file.open("w") as f: + f.write("\n") + f.write("\n") + f.write(f"Links for {package}\n") + f.write("\n") + f.write(f"

Links for {package}

\n") + print(f" 📝 Created new index file: {index_dir}/index.html") + + # Create new link + new_link = f'{wheel_path.name}
\n' + + if new_link in existing_links: + print(f" â„šī¸ Already in index: {index_dir}/index.html") + else: + with index_file.open("a") as f: + f.write(new_link) + print(f" ✅ Added to index: {index_dir}/index.html") + + print(f" đŸ“Ļ Package: {package}") + print(f" 🔖 Version: {info['version']}") + if cuda: + print(f" 🎮 CUDA: cu{cuda}") + print(f" 📍 URL: {download_url}") + + +def main(): + parser = argparse.ArgumentParser( + description="Update wheel index for flashinfer packages" + ) + parser.add_argument( + "--dist-dir", + default="dist", + help="Directory containing wheel files (default: dist)", + ) + parser.add_argument( + "--output-dir", + default="whl", + help="Output directory for index files (default: whl)", + ) + parser.add_argument( + "--base-url", + default="https://github.com/flashinfer-ai/flashinfer/releases/download", + help="Base URL for wheel downloads", + ) + parser.add_argument( + "--release-tag", + help="GitHub release tag (e.g., 'nightly' or 'v0.3.1'). If not specified, will be derived from version.", + ) + + args = parser.parse_args() + + update_index( + dist_dir=args.dist_dir, + output_dir=args.output_dir, + base_url=args.base_url, + release_tag=args.release_tag, + ) + -for path in sorted(pathlib.Path("dist").glob("*.whl")): - with open(path, "rb") as f: - sha256 = hashlib.sha256(f.read()).hexdigest() - ver, cu, torch = re.findall( - r"flashinfer_python-([0-9.]+(?:\.post[0-9]+)?)\+cu(\d+)torch([0-9.]+)-", - path.name, - )[0] - index_dir = pathlib.Path(f"flashinfer-whl/cu{cu}/torch{torch}/flashinfer-python") - index_dir.mkdir(exist_ok=True) - base_url = "https://github.com/flashinfer-ai/flashinfer/releases/download" - full_url = f"{base_url}/v{ver}/{path.name}#sha256={sha256}" - with (index_dir / "index.html").open("a") as f: - f.write(f'{path.name}
\n') +if __name__ == "__main__": + main() diff --git a/setup.py b/setup.py index 6c663c1cb6..8fdc0a76a5 100755 --- a/setup.py +++ b/setup.py @@ -30,12 +30,35 @@ def write_if_different(path: Path, content: str) -> None: def get_version(): + import os + package_version = (root / "version.txt").read_text().strip() - return f"{package_version}" + dev_suffix = os.environ.get("FLASHINFER_DEV_RELEASE_SUFFIX", "") + if dev_suffix: + package_version = f"{package_version}.dev{dev_suffix}" + return package_version + + +def get_git_version(): + """Get git commit hash.""" + import subprocess + + try: + git_version = ( + subprocess.check_output( + ["git", "rev-parse", "HEAD"], cwd=root, stderr=subprocess.DEVNULL + ) + .decode("ascii") + .strip() + ) + return git_version + except Exception: + return "unknown" def generate_build_meta() -> None: build_meta_str = f"__version__ = {get_version()!r}\n" + build_meta_str += f"__git_version__ = {get_git_version()!r}\n" write_if_different(root / "flashinfer" / "_build_meta.py", build_meta_str) From a0b9b3a4e889731e4e6c7ab3a77944de10601209 Mon Sep 17 00:00:00 2001 From: Zihao Ye Date: Sun, 5 Oct 2025 19:25:19 -0400 Subject: [PATCH 02/41] upd --- flashinfer-cubin/build_backend.py | 2 ++ flashinfer-jit-cache/build_backend.py | 22 ++++++++++++++++++++++ pyproject.toml | 2 +- 3 files changed, 25 insertions(+), 1 deletion(-) diff --git a/flashinfer-cubin/build_backend.py b/flashinfer-cubin/build_backend.py index 18dc48f53f..6d103dd792 100644 --- a/flashinfer-cubin/build_backend.py +++ b/flashinfer-cubin/build_backend.py @@ -17,9 +17,11 @@ if version_file.exists(): with open(version_file, "r") as f: version = f.read().strip() +git_version = _get_git_version() with open(build_meta_file, "w") as f: f.write('"""Build metadata for flashinfer package."""\n') f.write(f'__version__ = "{version}"\n') + f.write(f'__git_version__ = "{git_version}"\n') def _download_cubins(): diff --git a/flashinfer-jit-cache/build_backend.py b/flashinfer-jit-cache/build_backend.py index 7e0d9f4aaa..c8d736c602 100644 --- a/flashinfer-jit-cache/build_backend.py +++ b/flashinfer-jit-cache/build_backend.py @@ -22,15 +22,37 @@ # Add parent directory to path to import flashinfer modules sys.path.insert(0, str(Path(__file__).parent.parent)) + +def _get_git_version(): + """Get git commit hash.""" + import subprocess + + try: + git_version = ( + subprocess.check_output( + ["git", "rev-parse", "HEAD"], + cwd=Path(__file__).parent.parent, + stderr=subprocess.DEVNULL, + ) + .decode("ascii") + .strip() + ) + return git_version + except Exception: + return "unknown" + + # add flashinfer._build_meta, always override to ensure version is up-to-date build_meta_file = Path(__file__).parent.parent / "flashinfer" / "_build_meta.py" version_file = Path(__file__).parent.parent / "version.txt" if version_file.exists(): with open(version_file, "r") as f: version = f.read().strip() +git_version = _get_git_version() with open(build_meta_file, "w") as f: f.write('"""Build metadata for flashinfer package."""\n') f.write(f'__version__ = "{version}"\n') + f.write(f'__git_version__ = "{git_version}"\n') def get_version(): diff --git a/pyproject.toml b/pyproject.toml index ef13baa13f..bbf5314d7b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -24,7 +24,7 @@ dynamic = ["dependencies", "version"] license-files = ["LICENSE", "licenses/*"] [build-system] -requires = ["setuptools>=77", "packaging>=24"] +requires = ["setuptools>=77", "packaging>=24", "filelock"] build-backend = "custom_backend" backend-path = ["."] From 4e9bc162bb4aaf45fd6bd0516950470173cfb523 Mon Sep 17 00:00:00 2001 From: Zihao Ye Date: Sun, 5 Oct 2025 19:30:46 -0400 Subject: [PATCH 03/41] upd --- .github/workflows/nightly-release.yml | 6 +---- flashinfer-cubin/build_backend.py | 39 ++++++++++++++------------- pyproject.toml | 2 +- 3 files changed, 22 insertions(+), 25 deletions(-) diff --git a/.github/workflows/nightly-release.yml b/.github/workflows/nightly-release.yml index 9cfdaff778..67ae0b6fa1 100644 --- a/.github/workflows/nightly-release.yml +++ b/.github/workflows/nightly-release.yml @@ -77,12 +77,8 @@ jobs: - name: Verify version and git version run: | - tar -xzf dist/*.tar.gz -C /tmp - EXTRACTED_DIR=$(find /tmp -maxdepth 1 -name "flashinfer-python-*" -type d) - cd "$EXTRACTED_DIR" + pip install dist/*.tar.gz python -c " - import sys - sys.path.insert(0, '.') from flashinfer._build_meta import __version__, __git_version__ print(f'đŸ“Ļ Package version: {__version__}') print(f'🔖 Git version: {__git_version__}') diff --git a/flashinfer-cubin/build_backend.py b/flashinfer-cubin/build_backend.py index 6d103dd792..89c531af99 100644 --- a/flashinfer-cubin/build_backend.py +++ b/flashinfer-cubin/build_backend.py @@ -11,6 +11,26 @@ # Add parent directory to path to import artifacts module sys.path.insert(0, str(Path(__file__).parent.parent)) + +def _get_git_version(): + """Get git commit hash.""" + import subprocess + + try: + git_version = ( + subprocess.check_output( + ["git", "rev-parse", "HEAD"], + cwd=Path(__file__).parent.parent, + stderr=subprocess.DEVNULL, + ) + .decode("ascii") + .strip() + ) + return git_version + except Exception: + return "unknown" + + # add flashinfer._build_meta, always override to ensure version is up-to-date build_meta_file = Path(__file__).parent.parent / "flashinfer" / "_build_meta.py" version_file = Path(__file__).parent.parent / "version.txt" @@ -53,25 +73,6 @@ def _download_cubins(): os.environ.pop("FLASHINFER_CUBIN_DIR", None) -def _get_git_version(): - """Get git commit hash.""" - import subprocess - - try: - git_version = ( - subprocess.check_output( - ["git", "rev-parse", "HEAD"], - cwd=Path(__file__).parent.parent, - stderr=subprocess.DEVNULL, - ) - .decode("ascii") - .strip() - ) - return git_version - except Exception: - return "unknown" - - def _create_build_metadata(): """Create build metadata file with version information.""" version_file = Path(__file__).parent.parent / "version.txt" diff --git a/pyproject.toml b/pyproject.toml index bbf5314d7b..ef13baa13f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -24,7 +24,7 @@ dynamic = ["dependencies", "version"] license-files = ["LICENSE", "licenses/*"] [build-system] -requires = ["setuptools>=77", "packaging>=24", "filelock"] +requires = ["setuptools>=77", "packaging>=24"] build-backend = "custom_backend" backend-path = ["."] From 449eaf02e8ba9f7e5f6da41c587146298e0365c4 Mon Sep 17 00:00:00 2001 From: Zihao Ye Date: Sun, 5 Oct 2025 19:36:02 -0400 Subject: [PATCH 04/41] upd --- .github/workflows/nightly-release.yml | 18 ------------------ 1 file changed, 18 deletions(-) diff --git a/.github/workflows/nightly-release.yml b/.github/workflows/nightly-release.yml index 67ae0b6fa1..b167a3a525 100644 --- a/.github/workflows/nightly-release.yml +++ b/.github/workflows/nightly-release.yml @@ -75,15 +75,6 @@ jobs: python -m build --sdist ls -lh dist/ - - name: Verify version and git version - run: | - pip install dist/*.tar.gz - python -c " - from flashinfer._build_meta import __version__, __git_version__ - print(f'đŸ“Ļ Package version: {__version__}') - print(f'🔖 Git version: {__git_version__}') - " - - name: Upload flashinfer-python artifact uses: actions/upload-artifact@v4 with: @@ -124,15 +115,6 @@ jobs: mkdir -p ../dist cp dist/*.whl ../dist/ - - name: Verify version and git version - run: | - python -m pip install dist/*.whl - python -c " - import flashinfer_cubin - print(f'đŸ“Ļ Package version: {flashinfer_cubin.__version__}') - print(f'🔖 Git version: {flashinfer_cubin.__git_version__}') - " - - name: Upload flashinfer-cubin artifact uses: actions/upload-artifact@v4 with: From 9e70e3729cb0eb69aa8594a16fdb619602a8c813 Mon Sep 17 00:00:00 2001 From: Zihao Ye Date: Sun, 5 Oct 2025 19:41:25 -0400 Subject: [PATCH 05/41] upd --- flashinfer/aot.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/flashinfer/aot.py b/flashinfer/aot.py index 5062e7c576..81fc8e5eb8 100644 --- a/flashinfer/aot.py +++ b/flashinfer/aot.py @@ -525,6 +525,9 @@ def gen_all_modules( # Add cuDNN FMHA module jit_specs.append(gen_cudnn_fmha_module()) + # NOTE(Zihao): just for debugging, remove later + jit_specs = [gen_spdlog_module()] + # dedup names = set() ret: List[JitSpec] = [] From db75317029f946174d9a6ffb4a2e2a9e8cb6f324 Mon Sep 17 00:00:00 2001 From: Zihao Ye Date: Sun, 5 Oct 2025 19:49:45 -0400 Subject: [PATCH 06/41] upd --- .github/workflows/nightly-release.yml | 3 +++ 1 file changed, 3 insertions(+) diff --git a/.github/workflows/nightly-release.yml b/.github/workflows/nightly-release.yml index b167a3a525..9296881d1d 100644 --- a/.github/workflows/nightly-release.yml +++ b/.github/workflows/nightly-release.yml @@ -13,6 +13,9 @@ on: pull_request: # TODO: Remove this before merging - only for debugging this PR +permissions: + contents: write + jobs: setup: runs-on: ubuntu-latest From 8ab269d2930cae9022ad0b7e5b9ef7faf85176c0 Mon Sep 17 00:00:00 2001 From: Zihao Ye Date: Sun, 5 Oct 2025 20:01:06 -0400 Subject: [PATCH 07/41] upd --- .github/workflows/nightly-release.yml | 1 + flashinfer-cubin/download_cubins.py | 5 +++-- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/.github/workflows/nightly-release.yml b/.github/workflows/nightly-release.yml index 9296881d1d..fdece945f5 100644 --- a/.github/workflows/nightly-release.yml +++ b/.github/workflows/nightly-release.yml @@ -15,6 +15,7 @@ on: permissions: contents: write + id-token: write jobs: setup: diff --git a/flashinfer-cubin/download_cubins.py b/flashinfer-cubin/download_cubins.py index 2f2c847bbe..cbfd0d4f40 100644 --- a/flashinfer-cubin/download_cubins.py +++ b/flashinfer-cubin/download_cubins.py @@ -23,7 +23,7 @@ # Add parent directory to path to import flashinfer modules sys.path.insert(0, str(Path(__file__).parent.parent)) -from flashinfer.artifacts import download_artifacts +# from flashinfer.artifacts import download_artifacts from flashinfer.jit.cubin_loader import FLASHINFER_CUBINS_REPOSITORY @@ -69,7 +69,8 @@ def main(): # Use the existing download_artifacts function try: - download_artifacts() + # NOTE(Zihao): just for debugging, remove later + # download_artifacts() print("Download complete!") except Exception as e: print(f"Download failed: {e}") From 030567eb320f0055dd908b0d70ed3cbd1889cf60 Mon Sep 17 00:00:00 2001 From: Zihao Ye Date: Sun, 5 Oct 2025 20:05:25 -0400 Subject: [PATCH 08/41] remove unused files --- flashinfer-cubin/download_cubins.py | 81 ----------------------------- 1 file changed, 81 deletions(-) delete mode 100644 flashinfer-cubin/download_cubins.py diff --git a/flashinfer-cubin/download_cubins.py b/flashinfer-cubin/download_cubins.py deleted file mode 100644 index cbfd0d4f40..0000000000 --- a/flashinfer-cubin/download_cubins.py +++ /dev/null @@ -1,81 +0,0 @@ -#!/usr/bin/env python3 -""" -Copyright (c) 2025 by FlashInfer team. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -""" - -import os -import sys -import argparse -from pathlib import Path - -# Add parent directory to path to import flashinfer modules -sys.path.insert(0, str(Path(__file__).parent.parent)) - -# from flashinfer.artifacts import download_artifacts -from flashinfer.jit.cubin_loader import FLASHINFER_CUBINS_REPOSITORY - - -def main(): - parser = argparse.ArgumentParser( - description="Download FlashInfer cubins from artifactory" - ) - parser.add_argument( - "--output-dir", - "-o", - type=str, - default="flashinfer_cubin/cubins", - help="Output directory for cubins (default: flashinfer_cubin/cubins)", - ) - parser.add_argument( - "--threads", - "-t", - type=int, - default=4, - help="Number of download threads (default: 4)", - ) - parser.add_argument( - "--repository", - "-r", - type=str, - default=None, - help="Override the cubins repository URL", - ) - - args = parser.parse_args() - - # Set environment variables to control download behavior - if args.repository: - os.environ["FLASHINFER_CUBINS_REPOSITORY"] = args.repository - - os.environ["FLASHINFER_CUBIN_DIR"] = str(Path(args.output_dir).absolute()) - os.environ["FLASHINFER_CUBIN_DOWNLOAD_THREADS"] = str(args.threads) - - print(f"Downloading cubins to {args.output_dir}") - print( - f"Repository: {os.environ.get('FLASHINFER_CUBINS_REPOSITORY', FLASHINFER_CUBINS_REPOSITORY)}" - ) - - # Use the existing download_artifacts function - try: - # NOTE(Zihao): just for debugging, remove later - # download_artifacts() - print("Download complete!") - except Exception as e: - print(f"Download failed: {e}") - sys.exit(1) - - -if __name__ == "__main__": - main() From c58497043ee1629e2dedba8868883691bce02d1e Mon Sep 17 00:00:00 2001 From: Zihao Ye Date: Sun, 5 Oct 2025 20:11:03 -0400 Subject: [PATCH 09/41] upd --- .github/workflows/nightly-release.yml | 4 ---- 1 file changed, 4 deletions(-) diff --git a/.github/workflows/nightly-release.yml b/.github/workflows/nightly-release.yml index fdece945f5..b167a3a525 100644 --- a/.github/workflows/nightly-release.yml +++ b/.github/workflows/nightly-release.yml @@ -13,10 +13,6 @@ on: pull_request: # TODO: Remove this before merging - only for debugging this PR -permissions: - contents: write - id-token: write - jobs: setup: runs-on: ubuntu-latest From 0925678532273d5e2763e9cee027951ac9c96a2c Mon Sep 17 00:00:00 2001 From: Zihao Ye Date: Sun, 5 Oct 2025 23:27:04 -0400 Subject: [PATCH 10/41] upd --- custom_backend.py => build_backend.py | 44 +++++++++ flashinfer-cubin/build_backend.py | 23 +---- flashinfer-jit-cache/build_backend.py | 75 +++++++++----- flashinfer-jit-cache/setup.py | 135 -------------------------- flashinfer/build_utils.py | 46 +++++++++ pyproject.toml | 2 +- scripts/update_whl_index.py | 34 ++++--- setup.py | 90 ----------------- 8 files changed, 166 insertions(+), 283 deletions(-) rename custom_backend.py => build_backend.py (59%) delete mode 100644 flashinfer-jit-cache/setup.py create mode 100644 flashinfer/build_utils.py delete mode 100755 setup.py diff --git a/custom_backend.py b/build_backend.py similarity index 59% rename from custom_backend.py rename to build_backend.py index 0484d714fd..96465cf833 100644 --- a/custom_backend.py +++ b/build_backend.py @@ -1,12 +1,51 @@ +""" +Copyright (c) 2023 by FlashInfer team. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +import os import shutil from pathlib import Path from setuptools import build_meta as orig +from flashinfer.build_utils import get_git_version _root = Path(__file__).parent.resolve() _data_dir = _root / "flashinfer" / "data" +def write_if_different(path: Path, content: str) -> None: + if path.exists() and path.read_text() == content: + return + path.parent.mkdir(parents=True, exist_ok=True) + path.write_text(content) + + +def get_version(): + package_version = (_root / "version.txt").read_text().strip() + dev_suffix = os.environ.get("FLASHINFER_DEV_RELEASE_SUFFIX", "") + if dev_suffix: + package_version = f"{package_version}.dev{dev_suffix}" + return package_version + + +def generate_build_meta() -> None: + build_meta_str = f"__version__ = {get_version()!r}\n" + build_meta_str += f"__git_version__ = {get_git_version(cwd=_root)!r}\n" + write_if_different(_root / "flashinfer" / "_build_meta.py", build_meta_str) + + def _create_data_dir(): _data_dir.mkdir(parents=True, exist_ok=True) @@ -43,16 +82,21 @@ def _prepare_for_sdist(): def get_requires_for_build_wheel(config_settings=None): + generate_build_meta() _prepare_for_wheel() + return [] def get_requires_for_build_sdist(config_settings=None): + generate_build_meta() _prepare_for_sdist() return [] def get_requires_for_build_editable(config_settings=None): + generate_build_meta() _prepare_for_editable() + return [] def prepare_metadata_for_build_wheel(metadata_directory, config_settings=None): diff --git a/flashinfer-cubin/build_backend.py b/flashinfer-cubin/build_backend.py index 89c531af99..d08b6db97c 100644 --- a/flashinfer-cubin/build_backend.py +++ b/flashinfer-cubin/build_backend.py @@ -11,24 +11,7 @@ # Add parent directory to path to import artifacts module sys.path.insert(0, str(Path(__file__).parent.parent)) - -def _get_git_version(): - """Get git commit hash.""" - import subprocess - - try: - git_version = ( - subprocess.check_output( - ["git", "rev-parse", "HEAD"], - cwd=Path(__file__).parent.parent, - stderr=subprocess.DEVNULL, - ) - .decode("ascii") - .strip() - ) - return git_version - except Exception: - return "unknown" +from flashinfer.build_utils import get_git_version # add flashinfer._build_meta, always override to ensure version is up-to-date @@ -37,7 +20,7 @@ def _get_git_version(): if version_file.exists(): with open(version_file, "r") as f: version = f.read().strip() -git_version = _get_git_version() +git_version = get_git_version(cwd=Path(__file__).parent.parent) with open(build_meta_file, "w") as f: f.write('"""Build metadata for flashinfer package."""\n') f.write(f'__version__ = "{version}"\n') @@ -88,7 +71,7 @@ def _create_build_metadata(): version = f"{version}.dev{dev_suffix}" # Get git version - git_version = _get_git_version() + git_version = get_git_version(cwd=Path(__file__).parent.parent) # Create build metadata in the source tree package_dir = Path(__file__).parent / "flashinfer_cubin" diff --git a/flashinfer-jit-cache/build_backend.py b/flashinfer-jit-cache/build_backend.py index c8d736c602..3ee611a8d4 100644 --- a/flashinfer-jit-cache/build_backend.py +++ b/flashinfer-jit-cache/build_backend.py @@ -16,30 +16,15 @@ import sys import os +import platform from pathlib import Path from setuptools import build_meta as _orig +from wheel.bdist_wheel import bdist_wheel # Add parent directory to path to import flashinfer modules sys.path.insert(0, str(Path(__file__).parent.parent)) - -def _get_git_version(): - """Get git commit hash.""" - import subprocess - - try: - git_version = ( - subprocess.check_output( - ["git", "rev-parse", "HEAD"], - cwd=Path(__file__).parent.parent, - stderr=subprocess.DEVNULL, - ) - .decode("ascii") - .strip() - ) - return git_version - except Exception: - return "unknown" +from flashinfer.build_utils import get_git_version # add flashinfer._build_meta, always override to ensure version is up-to-date @@ -48,7 +33,7 @@ def _get_git_version(): if version_file.exists(): with open(version_file, "r") as f: version = f.read().strip() -git_version = _get_git_version() +git_version = get_git_version(cwd=Path(__file__).parent.parent) with open(build_meta_file, "w") as f: f.write('"""Build metadata for flashinfer package."""\n') f.write(f'__version__ = "{version}"\n') @@ -130,15 +115,61 @@ def _prepare_build(): print(f"Created build metadata file with version {version}") +class PlatformSpecificBdistWheel(bdist_wheel): + """Custom wheel builder that uses py_limited_api for cp39+.""" + + def finalize_options(self): + super().finalize_options() + # Force platform-specific wheel (not pure Python) + self.root_is_pure = False + # Use py_limited_api for cp39 (Python 3.9+) + self.py_limited_api = "cp39" + + def get_tag(self): + # Use py_limited_api tags + python_tag = "cp39" + abi_tag = "abi3" # Stable ABI tag + + # Get platform tag + machine = platform.machine() + if platform.system() == "Linux": + # Use manylinux_2_28 as specified + if machine == "x86_64": + plat_tag = "manylinux_2_28_x86_64" + elif machine == "aarch64": + plat_tag = "manylinux_2_28_aarch64" + else: + plat_tag = f"linux_{machine}" + else: + # For non-Linux platforms, use the default + import distutils.util + + plat_tag = distutils.util.get_platform().replace("-", "_").replace(".", "_") + + return python_tag, abi_tag, plat_tag + + def build_wheel(wheel_directory, config_settings=None, metadata_directory=None): """Build wheel with custom AOT module compilation.""" print("Building flashinfer-jit-cache wheel...") _prepare_build() - # Now build the wheel using setuptools - # The setup.py file will handle the platform-specific wheel naming - result = _orig.build_wheel(wheel_directory, config_settings, metadata_directory) + # Inject custom bdist_wheel class + import setuptools + + original_cmdclass = getattr(setuptools, "_GLOBAL_CMDCLASS", {}) + setuptools._GLOBAL_CMDCLASS = { + **original_cmdclass, + "bdist_wheel": PlatformSpecificBdistWheel, + } + + try: + # Now build the wheel using setuptools + result = _orig.build_wheel(wheel_directory, config_settings, metadata_directory) + finally: + # Restore original cmdclass + setuptools._GLOBAL_CMDCLASS = original_cmdclass return result diff --git a/flashinfer-jit-cache/setup.py b/flashinfer-jit-cache/setup.py deleted file mode 100644 index 2c0a0c7dbe..0000000000 --- a/flashinfer-jit-cache/setup.py +++ /dev/null @@ -1,135 +0,0 @@ -"""Setup script for flashinfer-jit-cache package.""" - -import os -import platform -from pathlib import Path -from setuptools import setup, find_packages -from wheel.bdist_wheel import bdist_wheel - - -def get_version(): - """Get version from version.txt file.""" - version_file = Path(__file__).parent.parent / "version.txt" - if version_file.exists(): - with open(version_file, "r") as f: - version = f.read().strip() - else: - version = "0.0.0" - - # Add dev suffix if specified - dev_suffix = os.environ.get("FLASHINFER_DEV_RELEASE_SUFFIX", "") - if dev_suffix: - version = f"{version}.dev{dev_suffix}" - - # Append CUDA version suffix if available - cuda_suffix = os.environ.get("CUDA_VERSION_SUFFIX", "") - if cuda_suffix: - # Use + to create a local version identifier that will appear in wheel name - version = f"{version}+{cuda_suffix}" - - return version - - -def get_git_version(): - """Get git commit hash.""" - import subprocess - - try: - git_version = ( - subprocess.check_output( - ["git", "rev-parse", "HEAD"], - cwd=Path(__file__).parent.parent, - stderr=subprocess.DEVNULL, - ) - .decode("ascii") - .strip() - ) - return git_version - except Exception: - return "unknown" - - -def generate_build_meta(): - """Generate build metadata file.""" - build_meta_file = Path(__file__).parent / "flashinfer_jit_cache" / "_build_meta.py" - version = get_version() - git_version = get_git_version() - with open(build_meta_file, "w") as f: - f.write('"""Build metadata for flashinfer-jit-cache package."""\n') - f.write(f'__version__ = "{version}"\n') - f.write(f'__git_version__ = "{git_version}"\n') - - -class PlatformSpecificBdistWheel(bdist_wheel): - """Custom wheel builder that uses py_limited_api for cp39+.""" - - def finalize_options(self): - super().finalize_options() - # Force platform-specific wheel (not pure Python) - self.root_is_pure = False - # Use py_limited_api for cp39 (Python 3.9+) - self.py_limited_api = "cp39" - - def get_tag(self): - # Use py_limited_api tags - python_tag = "cp39" - abi_tag = "abi3" # Stable ABI tag - - # Get platform tag - machine = platform.machine() - if platform.system() == "Linux": - # Use manylinux_2_28 as specified - if machine == "x86_64": - plat_tag = "manylinux_2_28_x86_64" - elif machine == "aarch64": - plat_tag = "manylinux_2_28_aarch64" - else: - plat_tag = f"linux_{machine}" - else: - # For non-Linux platforms, use the default - import distutils.util - - plat_tag = distutils.util.get_platform().replace("-", "_").replace(".", "_") - - return python_tag, abi_tag, plat_tag - - -if __name__ == "__main__": - generate_build_meta() - setup( - name="flashinfer-jit-cache", - version=get_version(), - description="Pre-compiled AOT modules for FlashInfer", - long_description="This package contains pre-compiled AOT modules for FlashInfer. It provides all necessary compiled shared libraries (.so files) for optimized inference operations.", - long_description_content_type="text/plain", - author="FlashInfer team", - maintainer="FlashInfer team", - url="https://github.com/flashinfer-ai/flashinfer", - project_urls={ - "Homepage": "https://github.com/flashinfer-ai/flashinfer", - "Documentation": "https://github.com/flashinfer-ai/flashinfer", - "Repository": "https://github.com/flashinfer-ai/flashinfer", - "Issue Tracker": "https://github.com/flashinfer-ai/flashinfer/issues", - }, - packages=find_packages(), - package_data={ - "flashinfer_jit_cache": ["jit_cache/**/*.so"], - }, - include_package_data=True, - python_requires=">=3.9", - classifiers=[ - "Development Status :: 4 - Beta", - "Intended Audience :: Developers", - "Operating System :: POSIX :: Linux", - "Programming Language :: Python :: 3", - "Programming Language :: Python :: 3.9", - "Programming Language :: Python :: 3.10", - "Programming Language :: Python :: 3.11", - "Programming Language :: Python :: 3.12", - "Programming Language :: Python :: 3.13", - "Topic :: Software Development :: Libraries :: Python Modules", - "Topic :: Scientific/Engineering :: Artificial Intelligence", - ], - license="Apache-2.0", - cmdclass={"bdist_wheel": PlatformSpecificBdistWheel}, - ) diff --git a/flashinfer/build_utils.py b/flashinfer/build_utils.py new file mode 100644 index 0000000000..726a628204 --- /dev/null +++ b/flashinfer/build_utils.py @@ -0,0 +1,46 @@ +""" +Copyright (c) 2025 by FlashInfer team. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +"""Shared build utilities for flashinfer packages.""" + +import subprocess +from pathlib import Path +from typing import Optional + + +def get_git_version(cwd: Optional[Path] = None) -> str: + """ + Get git commit hash. + + Args: + cwd: Working directory for git command. If None, uses current directory. + + Returns: + Git commit hash or "unknown" if git is not available. + """ + try: + git_version = ( + subprocess.check_output( + ["git", "rev-parse", "HEAD"], + cwd=cwd, + stderr=subprocess.DEVNULL, + ) + .decode("ascii") + .strip() + ) + return git_version + except Exception: + return "unknown" diff --git a/pyproject.toml b/pyproject.toml index ef13baa13f..2206a11afe 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -25,7 +25,7 @@ license-files = ["LICENSE", "licenses/*"] [build-system] requires = ["setuptools>=77", "packaging>=24"] -build-backend = "custom_backend" +build-backend = "build_backend" backend-path = ["."] [tool.codespell] diff --git a/scripts/update_whl_index.py b/scripts/update_whl_index.py index 01b6187996..2f16dc869d 100644 --- a/scripts/update_whl_index.py +++ b/scripts/update_whl_index.py @@ -129,30 +129,34 @@ def update_index( # Update index.html index_file = index_dir / "index.html" - # Read existing content to avoid duplicates - existing_links = set() + # Read existing links to avoid duplicates + links = set() if index_file.exists(): with index_file.open("r") as f: - existing_links = set(f.readlines()) - else: - # Create new index file with HTML header + content = f.read() + # Simple regex to extract the tags + links.update(re.findall(r'.*?
\n', content)) + + # Create and add new link + new_link = f'{wheel_path.name}
\n' + is_new = new_link not in links + if is_new: + links.add(new_link) + + # Write the complete, valid HTML file with index_file.open("w") as f: f.write("\n") f.write("\n") f.write(f"Links for {package}\n") f.write("\n") f.write(f"

Links for {package}

\n") - print(f" 📝 Created new index file: {index_dir}/index.html") - - # Create new link - new_link = f'{wheel_path.name}
\n' - - if new_link in existing_links: - print(f" â„šī¸ Already in index: {index_dir}/index.html") - else: - with index_file.open("a") as f: - f.write(new_link) + for link in sorted(list(links)): + f.write(link) + f.write("\n") + f.write("\n") print(f" ✅ Added to index: {index_dir}/index.html") + else: + print(f" â„šī¸ Already in index: {index_dir}/index.html") print(f" đŸ“Ļ Package: {package}") print(f" 🔖 Version: {info['version']}") diff --git a/setup.py b/setup.py deleted file mode 100755 index 8fdc0a76a5..0000000000 --- a/setup.py +++ /dev/null @@ -1,90 +0,0 @@ -""" -Copyright (c) 2023 by FlashInfer team. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -""" - -from pathlib import Path -from typing import List, Mapping - -import setuptools - -root = Path(__file__).parent.resolve() - - -def write_if_different(path: Path, content: str) -> None: - if path.exists() and path.read_text() == content: - return - path.parent.mkdir(parents=True, exist_ok=True) - path.write_text(content) - - -def get_version(): - import os - - package_version = (root / "version.txt").read_text().strip() - dev_suffix = os.environ.get("FLASHINFER_DEV_RELEASE_SUFFIX", "") - if dev_suffix: - package_version = f"{package_version}.dev{dev_suffix}" - return package_version - - -def get_git_version(): - """Get git commit hash.""" - import subprocess - - try: - git_version = ( - subprocess.check_output( - ["git", "rev-parse", "HEAD"], cwd=root, stderr=subprocess.DEVNULL - ) - .decode("ascii") - .strip() - ) - return git_version - except Exception: - return "unknown" - - -def generate_build_meta() -> None: - build_meta_str = f"__version__ = {get_version()!r}\n" - build_meta_str += f"__git_version__ = {get_git_version()!r}\n" - write_if_different(root / "flashinfer" / "_build_meta.py", build_meta_str) - - -ext_modules: List[setuptools.Extension] = [] -cmdclass: Mapping[str, type[setuptools.Command]] = {} -install_requires = [ - "numpy", - "torch", - "ninja", - "requests", - "nvidia-ml-py", - "einops", - "click", - "tqdm", - "tabulate", - "apache-tvm-ffi==0.1.0b15", - "packaging>=24.2", - "nvidia-cudnn-frontend>=1.13.0", - "nvidia-cutlass-dsl>=4.2.1", -] -generate_build_meta() - - -setuptools.setup( - version=get_version(), - ext_modules=ext_modules, - cmdclass=cmdclass, - install_requires=install_requires, -) From 13b44e51b5543310f543e6e22c198226d40882ac Mon Sep 17 00:00:00 2001 From: Zihao Ye Date: Sun, 5 Oct 2025 23:31:10 -0400 Subject: [PATCH 11/41] upd --- build_backend.py | 2 +- build_utils.py | 46 +++++++++++++++++++++++++++ flashinfer-cubin/build_backend.py | 2 +- flashinfer-jit-cache/build_backend.py | 2 +- 4 files changed, 49 insertions(+), 3 deletions(-) create mode 100644 build_utils.py diff --git a/build_backend.py b/build_backend.py index 96465cf833..2ae7f82c57 100644 --- a/build_backend.py +++ b/build_backend.py @@ -19,7 +19,7 @@ from pathlib import Path from setuptools import build_meta as orig -from flashinfer.build_utils import get_git_version +from build_utils import get_git_version _root = Path(__file__).parent.resolve() _data_dir = _root / "flashinfer" / "data" diff --git a/build_utils.py b/build_utils.py new file mode 100644 index 0000000000..726a628204 --- /dev/null +++ b/build_utils.py @@ -0,0 +1,46 @@ +""" +Copyright (c) 2025 by FlashInfer team. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +"""Shared build utilities for flashinfer packages.""" + +import subprocess +from pathlib import Path +from typing import Optional + + +def get_git_version(cwd: Optional[Path] = None) -> str: + """ + Get git commit hash. + + Args: + cwd: Working directory for git command. If None, uses current directory. + + Returns: + Git commit hash or "unknown" if git is not available. + """ + try: + git_version = ( + subprocess.check_output( + ["git", "rev-parse", "HEAD"], + cwd=cwd, + stderr=subprocess.DEVNULL, + ) + .decode("ascii") + .strip() + ) + return git_version + except Exception: + return "unknown" diff --git a/flashinfer-cubin/build_backend.py b/flashinfer-cubin/build_backend.py index d08b6db97c..62512ee09b 100644 --- a/flashinfer-cubin/build_backend.py +++ b/flashinfer-cubin/build_backend.py @@ -11,7 +11,7 @@ # Add parent directory to path to import artifacts module sys.path.insert(0, str(Path(__file__).parent.parent)) -from flashinfer.build_utils import get_git_version +from build_utils import get_git_version # add flashinfer._build_meta, always override to ensure version is up-to-date diff --git a/flashinfer-jit-cache/build_backend.py b/flashinfer-jit-cache/build_backend.py index 3ee611a8d4..b0e0636806 100644 --- a/flashinfer-jit-cache/build_backend.py +++ b/flashinfer-jit-cache/build_backend.py @@ -24,7 +24,7 @@ # Add parent directory to path to import flashinfer modules sys.path.insert(0, str(Path(__file__).parent.parent)) -from flashinfer.build_utils import get_git_version +from build_utils import get_git_version # add flashinfer._build_meta, always override to ensure version is up-to-date From 95f194df6ff8053fa29069b3cf2df885645e4329 Mon Sep 17 00:00:00 2001 From: Zihao Ye Date: Sun, 5 Oct 2025 23:31:43 -0400 Subject: [PATCH 12/41] remove unused files --- flashinfer/build_utils.py | 46 --------------------------------------- 1 file changed, 46 deletions(-) delete mode 100644 flashinfer/build_utils.py diff --git a/flashinfer/build_utils.py b/flashinfer/build_utils.py deleted file mode 100644 index 726a628204..0000000000 --- a/flashinfer/build_utils.py +++ /dev/null @@ -1,46 +0,0 @@ -""" -Copyright (c) 2025 by FlashInfer team. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -""" - -"""Shared build utilities for flashinfer packages.""" - -import subprocess -from pathlib import Path -from typing import Optional - - -def get_git_version(cwd: Optional[Path] = None) -> str: - """ - Get git commit hash. - - Args: - cwd: Working directory for git command. If None, uses current directory. - - Returns: - Git commit hash or "unknown" if git is not available. - """ - try: - git_version = ( - subprocess.check_output( - ["git", "rev-parse", "HEAD"], - cwd=cwd, - stderr=subprocess.DEVNULL, - ) - .decode("ascii") - .strip() - ) - return git_version - except Exception: - return "unknown" From d3efce70206db772de5155fcbca0bde9191032e4 Mon Sep 17 00:00:00 2001 From: Zihao Ye Date: Mon, 6 Oct 2025 00:03:18 -0400 Subject: [PATCH 13/41] upd --- flashinfer-jit-cache/build_backend.py | 44 +++++++++++++++------------ 1 file changed, 25 insertions(+), 19 deletions(-) diff --git a/flashinfer-jit-cache/build_backend.py b/flashinfer-jit-cache/build_backend.py index b0e0636806..b0066a797c 100644 --- a/flashinfer-jit-cache/build_backend.py +++ b/flashinfer-jit-cache/build_backend.py @@ -40,24 +40,36 @@ f.write(f'__git_version__ = "{git_version}"\n') -def get_version(): +def _create_build_metadata(): + """Create build metadata file with version information.""" version_file = Path(__file__).parent.parent / "version.txt" if version_file.exists(): with open(version_file, "r") as f: version = f.read().strip() else: - version = "0.0.0+unknown" + version = "0.0.0" + + # Add dev suffix if specified + dev_suffix = os.environ.get("FLASHINFER_DEV_RELEASE_SUFFIX", "") + if dev_suffix: + version = f"{version}.dev{dev_suffix}" + + # Get git version + git_version = get_git_version() # Append CUDA version suffix if available cuda_suffix = os.environ.get("CUDA_VERSION_SUFFIX", "") if cuda_suffix: - # Replace + with . for proper version formatting - if "+" in version: - base_version, local = version.split("+", 1) - version = f"{base_version}+{cuda_suffix}.{local}" - else: - version = f"{version}+{cuda_suffix}" + # Use + to create a local version identifier that will appear in wheel name + version = f"{version}+{cuda_suffix}" + build_meta_file = Path(__file__).parent / "flashinfer_jit_cache" / "_build_meta.py" + with open(build_meta_file, "w") as f: + f.write('"""Build metadata for flashinfer-jit-cache package."""\n') + f.write(f'__version__ = "{version}"\n') + f.write(f'__git_version__ = "{git_version}"\n') + + print(f"Created build metadata file with version {version}") return version @@ -82,8 +94,7 @@ def compile_jit_cache(output_dir: Path, verbose: bool = True): ) -def _prepare_build(): - """Shared preparation logic for both wheel and editable builds.""" +def _build_aot_modules(): # First, ensure AOT modules are compiled aot_package_dir = Path(__file__).parent / "flashinfer_jit_cache" / "jit_cache" aot_package_dir.mkdir(parents=True, exist_ok=True) @@ -103,16 +114,11 @@ def _prepare_build(): print(f"Failed to compile AOT modules: {e}") raise - # Create build metadata file with version information - package_dir = Path(__file__).parent / "flashinfer_jit_cache" - build_meta_file = package_dir / "_build_meta.py" - version = get_version() - - with open(build_meta_file, "w") as f: - f.write('"""Build metadata for flashinfer-jit-cache package."""\n') - f.write(f'__version__ = "{version}"\n') - print(f"Created build metadata file with version {version}") +def _prepare_build(): + """Shared preparation logic for both wheel and editable builds.""" + _create_build_metadata() + _build_aot_modules() class PlatformSpecificBdistWheel(bdist_wheel): From a0c0f8963aeffc53e67b47442b0fbeb3d8fe81ca Mon Sep 17 00:00:00 2001 From: Zihao Ye Date: Mon, 6 Oct 2025 00:23:24 -0400 Subject: [PATCH 14/41] upd --- flashinfer-jit-cache/build_backend.py | 47 +++++++++++++++++---------- scripts/update_whl_index.py | 21 +++++++++--- 2 files changed, 46 insertions(+), 22 deletions(-) diff --git a/flashinfer-jit-cache/build_backend.py b/flashinfer-jit-cache/build_backend.py index b0066a797c..10ada07e84 100644 --- a/flashinfer-jit-cache/build_backend.py +++ b/flashinfer-jit-cache/build_backend.py @@ -155,29 +155,31 @@ def get_tag(self): return python_tag, abi_tag, plat_tag +class _MonkeyPatchBdistWheel: + """Context manager to temporarily replace bdist_wheel with our custom class.""" + + def __enter__(self): + from setuptools.command import bdist_wheel as setuptools_bdist_wheel + + self.original_bdist_wheel = setuptools_bdist_wheel.bdist_wheel + setuptools_bdist_wheel.bdist_wheel = PlatformSpecificBdistWheel + + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + from setuptools.command import bdist_wheel as setuptools_bdist_wheel + + setuptools_bdist_wheel.bdist_wheel = self.original_bdist_wheel + + def build_wheel(wheel_directory, config_settings=None, metadata_directory=None): """Build wheel with custom AOT module compilation.""" print("Building flashinfer-jit-cache wheel...") _prepare_build() - # Inject custom bdist_wheel class - import setuptools - - original_cmdclass = getattr(setuptools, "_GLOBAL_CMDCLASS", {}) - setuptools._GLOBAL_CMDCLASS = { - **original_cmdclass, - "bdist_wheel": PlatformSpecificBdistWheel, - } - - try: - # Now build the wheel using setuptools - result = _orig.build_wheel(wheel_directory, config_settings, metadata_directory) - finally: - # Restore original cmdclass - setuptools._GLOBAL_CMDCLASS = original_cmdclass - - return result + with _MonkeyPatchBdistWheel(): + return _orig.build_wheel(wheel_directory, config_settings, metadata_directory) def build_editable(wheel_directory, config_settings=None, metadata_directory=None): @@ -196,9 +198,18 @@ def build_editable(wheel_directory, config_settings=None, metadata_directory=Non return result +def prepare_metadata_for_build_wheel(metadata_directory, config_settings=None): + """Prepare metadata with platform-specific wheel tags.""" + _create_build_metadata() + + with _MonkeyPatchBdistWheel(): + return _orig.prepare_metadata_for_build_wheel( + metadata_directory, config_settings + ) + + # Export the required interface get_requires_for_build_wheel = _orig.get_requires_for_build_wheel -prepare_metadata_for_build_wheel = _orig.prepare_metadata_for_build_wheel get_requires_for_build_editable = getattr( _orig, "get_requires_for_build_editable", None ) diff --git a/scripts/update_whl_index.py b/scripts/update_whl_index.py index 2f16dc869d..cca868cf34 100644 --- a/scripts/update_whl_index.py +++ b/scripts/update_whl_index.py @@ -75,6 +75,7 @@ def update_index( output_dir: str = "whl", base_url: str = "https://github.com/flashinfer-ai/flashinfer/releases/download", release_tag: Optional[str] = None, + nightly: bool = False, ): """ Update wheel index from dist directory. @@ -84,6 +85,7 @@ def update_index( output_dir: Output directory for index files base_url: Base URL for wheel downloads release_tag: GitHub release tag (e.g., 'nightly' or 'v0.3.1') + nightly: If True, update index to whl/nightly subdirectory """ dist_path = pathlib.Path(dist_dir) if not dist_path.exists(): @@ -113,12 +115,17 @@ def update_index( package = info["package"] cuda = info["cuda"] + # Add nightly subdirectory if nightly flag is set + base_output = pathlib.Path(output_dir) + if nightly: + base_output = base_output / "nightly" + if cuda: - # CUDA-specific index for jit-cache: whl/cu130/flashinfer-jit-cache/ - index_dir = pathlib.Path(output_dir) / f"cu{cuda}" / package + # CUDA-specific index for jit-cache: whl/nightly/cu130/flashinfer-jit-cache/ + index_dir = base_output / f"cu{cuda}" / package else: - # No CUDA version for python/cubin: whl/flashinfer-python/ - index_dir = pathlib.Path(output_dir) / package + # No CUDA version for python/cubin: whl/nightly/flashinfer-python/ + index_dir = base_output / package index_dir.mkdir(parents=True, exist_ok=True) @@ -188,6 +195,11 @@ def main(): "--release-tag", help="GitHub release tag (e.g., 'nightly' or 'v0.3.1'). If not specified, will be derived from version.", ) + parser.add_argument( + "--nightly", + action="store_true", + help="Update index to whl/nightly subdirectory for nightly releases", + ) args = parser.parse_args() @@ -196,6 +208,7 @@ def main(): output_dir=args.output_dir, base_url=args.base_url, release_tag=args.release_tag, + nightly=args.nightly, ) From 80a6f5e21401c2282fa7b0e552e6086b5a43e50c Mon Sep 17 00:00:00 2001 From: Zihao Ye Date: Mon, 6 Oct 2025 00:36:00 -0400 Subject: [PATCH 15/41] upd --- .github/workflows/nightly-release.yml | 3 ++- build_backend.py | 27 +++++++++++++-------------- pyproject.toml | 3 +++ 3 files changed, 18 insertions(+), 15 deletions(-) diff --git a/.github/workflows/nightly-release.yml b/.github/workflows/nightly-release.yml index b167a3a525..4c3df9184a 100644 --- a/.github/workflows/nightly-release.yml +++ b/.github/workflows/nightly-release.yml @@ -310,7 +310,8 @@ jobs: python3 scripts/update_whl_index.py \ --dist-dir dist \ --output-dir flashinfer-whl \ - --release-tag "${{ needs.setup.outputs.release_tag }}" + --release-tag "${{ needs.setup.outputs.release_tag }}" \ + --nightly - name: Push wheel index run: | diff --git a/build_backend.py b/build_backend.py index 2ae7f82c57..b5568308cb 100644 --- a/build_backend.py +++ b/build_backend.py @@ -25,13 +25,6 @@ _data_dir = _root / "flashinfer" / "data" -def write_if_different(path: Path, content: str) -> None: - if path.exists() and path.read_text() == content: - return - path.parent.mkdir(parents=True, exist_ok=True) - path.write_text(content) - - def get_version(): package_version = (_root / "version.txt").read_text().strip() dev_suffix = os.environ.get("FLASHINFER_DEV_RELEASE_SUFFIX", "") @@ -40,10 +33,19 @@ def get_version(): return package_version -def generate_build_meta() -> None: - build_meta_str = f"__version__ = {get_version()!r}\n" - build_meta_str += f"__git_version__ = {get_git_version(cwd=_root)!r}\n" - write_if_different(_root / "flashinfer" / "_build_meta.py", build_meta_str) +# Create _build_meta.py at import time so setuptools can read the version +build_meta_file = _root / "flashinfer" / "_build_meta.py" +with open(build_meta_file, "w") as f: + f.write('"""Build metadata for flashinfer package."""\n') + f.write(f'__version__ = "{get_version()}"\n') + f.write(f'__git_version__ = "{get_git_version(cwd=_root)}"\n') + + +def write_if_different(path: Path, content: str) -> None: + if path.exists() and path.read_text() == content: + return + path.parent.mkdir(parents=True, exist_ok=True) + path.write_text(content) def _create_data_dir(): @@ -82,19 +84,16 @@ def _prepare_for_sdist(): def get_requires_for_build_wheel(config_settings=None): - generate_build_meta() _prepare_for_wheel() return [] def get_requires_for_build_sdist(config_settings=None): - generate_build_meta() _prepare_for_sdist() return [] def get_requires_for_build_editable(config_settings=None): - generate_build_meta() _prepare_for_editable() return [] diff --git a/pyproject.toml b/pyproject.toml index 2206a11afe..613a9fa8ae 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -40,6 +40,9 @@ skip = [ [tool.setuptools] include-package-data = false +[tool.setuptools.dynamic] +version = {attr = "flashinfer.__version__"} + [tool.setuptools.packages.find] where = ["."] include = ["flashinfer*"] From e4bae87127008c4e7e30ca7807a7fd318c4ae47f Mon Sep 17 00:00:00 2001 From: Zihao Ye Date: Mon, 6 Oct 2025 00:40:45 -0400 Subject: [PATCH 16/41] upd --- pyproject.toml | 2 +- scripts/update_whl_index.py | 47 ++++++++++++++++++++++++++++++++++++- 2 files changed, 47 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 613a9fa8ae..6cdd5deaa3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -41,7 +41,7 @@ skip = [ include-package-data = false [tool.setuptools.dynamic] -version = {attr = "flashinfer.__version__"} +version = {attr = "flashinfer._build_meta.__version__"} [tool.setuptools.packages.find] where = ["."] diff --git a/scripts/update_whl_index.py b/scripts/update_whl_index.py index cca868cf34..474ec61ea9 100644 --- a/scripts/update_whl_index.py +++ b/scripts/update_whl_index.py @@ -70,6 +70,40 @@ def compute_sha256(file_path: pathlib.Path) -> str: return hashlib.sha256(f.read()).hexdigest() +def generate_directory_index(directory: pathlib.Path): + """Generate index.html for a directory listing its subdirectories.""" + # Get all subdirectories + subdirs = sorted([d for d in directory.iterdir() if d.is_dir()]) + + if not subdirs: + return + + index_file = directory / "index.html" + + # Generate HTML for directory listing + with index_file.open("w") as f: + f.write("\n") + f.write("\n") + f.write(f"Index of {directory.name or 'root'}\n") + f.write("\n") + f.write(f"

Index of {directory.name or 'root'}

\n") + + for subdir in subdirs: + f.write(f'{subdir.name}/
\n') + + f.write("\n") + f.write("\n") + + +def update_parent_indices(leaf_dir: pathlib.Path, root_dir: pathlib.Path): + """Recursively update index.html for all parent directories.""" + current = leaf_dir.parent + + while current >= root_dir and current != current.parent: + generate_directory_index(current) + current = current.parent + + def update_index( dist_dir: str = "dist", output_dir: str = "whl", @@ -85,7 +119,7 @@ def update_index( output_dir: Output directory for index files base_url: Base URL for wheel downloads release_tag: GitHub release tag (e.g., 'nightly' or 'v0.3.1') - nightly: If True, update index to whl/nightly subdirectory + nightly: If True, update index to whl/nightly subdirectory for nightly releases """ dist_path = pathlib.Path(dist_dir) if not dist_path.exists(): @@ -99,6 +133,9 @@ def update_index( print(f"Found {len(wheels)} wheel file(s)") + # Track all directories that need parent index updates + created_dirs = set() + for wheel_path in wheels: print(f"\nProcessing: {wheel_path.name}") @@ -128,6 +165,7 @@ def update_index( index_dir = base_output / package index_dir.mkdir(parents=True, exist_ok=True) + created_dirs.add(index_dir) # Construct download URL tag = release_tag or f"v{info['version'].split('+')[0].split('.dev')[0]}" @@ -171,6 +209,13 @@ def update_index( print(f" 🎮 CUDA: cu{cuda}") print(f" 📍 URL: {download_url}") + # Update parent directory indices + print("\n📂 Updating parent directory indices...") + root_output = pathlib.Path(output_dir) + for leaf_dir in created_dirs: + update_parent_indices(leaf_dir, root_output) + print(" ✅ Parent indices updated") + def main(): parser = argparse.ArgumentParser( From ff968433672b0c871c5a6721ad897412000c272b Mon Sep 17 00:00:00 2001 From: Zihao Ye Date: Mon, 6 Oct 2025 00:51:53 -0400 Subject: [PATCH 17/41] upd --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 6cdd5deaa3..93671270d9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -24,7 +24,7 @@ dynamic = ["dependencies", "version"] license-files = ["LICENSE", "licenses/*"] [build-system] -requires = ["setuptools>=77", "packaging>=24"] +requires = ["setuptools>=77", "packaging>=24", "apache-tvm-ffi==0.1.0b15"] build-backend = "build_backend" backend-path = ["."] From 43bf95c69852e803c27b98c375431599730ba348 Mon Sep 17 00:00:00 2001 From: Zihao Ye Date: Mon, 6 Oct 2025 01:26:37 -0400 Subject: [PATCH 18/41] add unittest following build --- .github/workflows/nightly-release.yml | 63 ++++++++++++++++++++++++++- scripts/task_jit_run_tests_part1.sh | 5 ++- scripts/task_jit_run_tests_part2.sh | 5 ++- scripts/task_jit_run_tests_part3.sh | 5 ++- scripts/task_jit_run_tests_part4.sh | 5 ++- scripts/task_jit_run_tests_part5.sh | 5 ++- scripts/task_test_nightly_build.sh | 32 ++++++++++++++ 7 files changed, 114 insertions(+), 6 deletions(-) create mode 100755 scripts/task_test_nightly_build.sh diff --git a/.github/workflows/nightly-release.yml b/.github/workflows/nightly-release.yml index 4c3df9184a..2f3cdd6c0d 100644 --- a/.github/workflows/nightly-release.yml +++ b/.github/workflows/nightly-release.yml @@ -277,8 +277,69 @@ jobs: done done + test-nightly-build: + needs: [setup, build-flashinfer-python, build-flashinfer-cubin, build-flashinfer-jit-cache] + strategy: + fail-fast: false + matrix: + cuda: ["12.9", "13.0"] + test-shard: [1, 2, 3, 4, 5] + runs-on: [self-hosted, G5, X64] + + steps: + - name: Display Machine Information + run: | + echo "CPU: $(nproc) cores, $(lscpu | grep 'Model name' | cut -d':' -f2 | xargs)" + echo "RAM: $(free -h | awk '/^Mem:/ {print $7 " available out of " $2}')" + echo "Disk: $(df -h / | awk 'NR==2 {print $4 " available out of " $2}')" + echo "Architecture: $(uname -m)" + nvidia-smi + + - name: Checkout code + uses: actions/checkout@v4 + with: + submodules: true + + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: '3.10' + + - name: Download flashinfer-python artifact + uses: actions/download-artifact@v4 + with: + name: flashinfer-python-sdist + path: dist-python/ + + - name: Download flashinfer-cubin artifact + uses: actions/download-artifact@v4 + with: + name: flashinfer-cubin-wheel + path: dist-cubin/ + + - name: Download flashinfer-jit-cache artifact + uses: actions/download-artifact@v4 + with: + name: jit-cache-cu${{ matrix.cuda == '12.9' && '129' || '130' }}-x86_64 + path: dist-jit-cache/ + + - name: Get Docker image tag + id: docker-tag + run: | + CUDA_VERSION="cu${{ matrix.cuda == '12.9' && '129' || '130' }}" + DOCKER_TAG=$(grep "flashinfer/flashinfer-ci-${CUDA_VERSION}" ci/docker-tags.yml | cut -d':' -f2 | tr -d ' ') + echo "cuda_version=${CUDA_VERSION}" >> $GITHUB_OUTPUT + echo "tag=${DOCKER_TAG}" >> $GITHUB_OUTPUT + + - name: Run nightly build tests in Docker (shard ${{ matrix.test-shard }}) + env: + CUDA_VISIBLE_DEVICES: 0 + run: | + DOCKER_IMAGE="flashinfer/flashinfer-ci-${{ steps.docker-tag.outputs.cuda_version }}:${{ steps.docker-tag.outputs.tag }}" + bash ci/bash.sh ${DOCKER_IMAGE} -e TEST_SHARD ${{ matrix.test-shard }} ./scripts/task_test_nightly_build.sh + update-wheel-index: - needs: [setup, create-release] + needs: [setup, create-release, test-nightly-build] runs-on: ubuntu-latest steps: - name: Checkout flashinfer repo diff --git a/scripts/task_jit_run_tests_part1.sh b/scripts/task_jit_run_tests_part1.sh index 38f89999a3..ff7ac2663e 100755 --- a/scripts/task_jit_run_tests_part1.sh +++ b/scripts/task_jit_run_tests_part1.sh @@ -4,8 +4,11 @@ set -eo pipefail set -x : ${MAX_JOBS:=$(nproc)} : ${CUDA_VISIBLE_DEVICES:=0} +: ${SKIP_INSTALL:=0} -pip install -e . -v +if [ "$SKIP_INSTALL" = "0" ]; then + pip install -e . -v +fi # pytest -s tests/gemm/test_group_gemm.py pytest -s tests/attention/test_logits_cap.py diff --git a/scripts/task_jit_run_tests_part2.sh b/scripts/task_jit_run_tests_part2.sh index 7fc15fef25..9b93133f5b 100755 --- a/scripts/task_jit_run_tests_part2.sh +++ b/scripts/task_jit_run_tests_part2.sh @@ -4,8 +4,11 @@ set -eo pipefail set -x : ${MAX_JOBS:=$(nproc)} : ${CUDA_VISIBLE_DEVICES:=0} +: ${SKIP_INSTALL:=0} -pip install -e . -v +if [ "$SKIP_INSTALL" = "0" ]; then + pip install -e . -v +fi pytest -s tests/utils/test_block_sparse.py pytest -s tests/utils/test_jit_example.py diff --git a/scripts/task_jit_run_tests_part3.sh b/scripts/task_jit_run_tests_part3.sh index da342eec19..da82f5af1d 100755 --- a/scripts/task_jit_run_tests_part3.sh +++ b/scripts/task_jit_run_tests_part3.sh @@ -4,7 +4,10 @@ set -eo pipefail set -x : ${MAX_JOBS:=$(nproc)} : ${CUDA_VISIBLE_DEVICES:=0} +: ${SKIP_INSTALL:=0} -pip install -e . -v +if [ "$SKIP_INSTALL" = "0" ]; then + pip install -e . -v +fi pytest -s tests/utils/test_sampling.py diff --git a/scripts/task_jit_run_tests_part4.sh b/scripts/task_jit_run_tests_part4.sh index cea153b6b9..22e21c1d21 100755 --- a/scripts/task_jit_run_tests_part4.sh +++ b/scripts/task_jit_run_tests_part4.sh @@ -4,8 +4,11 @@ set -eo pipefail set -x : ${MAX_JOBS:=$(nproc)} : ${CUDA_VISIBLE_DEVICES:=0} +: ${SKIP_INSTALL:=0} -pip install -e . -v +if [ "$SKIP_INSTALL" = "0" ]; then + pip install -e . -v +fi export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True # avoid memory fragmentation pytest -s tests/attention/test_deepseek_mla.py diff --git a/scripts/task_jit_run_tests_part5.sh b/scripts/task_jit_run_tests_part5.sh index 58bb21babd..5606673aef 100755 --- a/scripts/task_jit_run_tests_part5.sh +++ b/scripts/task_jit_run_tests_part5.sh @@ -4,7 +4,10 @@ set -eo pipefail set -x : ${MAX_JOBS:=$(nproc)} : ${CUDA_VISIBLE_DEVICES:=0} +: ${SKIP_INSTALL:=0} -pip install -e . -v +if [ "$SKIP_INSTALL" = "0" ]; then + pip install -e . -v +fi pytest -s tests/utils/test_logits_processor.py diff --git a/scripts/task_test_nightly_build.sh b/scripts/task_test_nightly_build.sh new file mode 100755 index 0000000000..6428a082eb --- /dev/null +++ b/scripts/task_test_nightly_build.sh @@ -0,0 +1,32 @@ +#!/bin/bash + +set -eo pipefail +set -x + +# This script installs nightly build packages and runs tests +# Expected dist directories to be in current directory or specified via env vars + +: ${TEST_SHARD:=1} +: ${CUDA_VISIBLE_DEVICES:=0} +: ${DIST_CUBIN_DIR:=dist-cubin} +: ${DIST_JIT_CACHE_DIR:=dist-jit-cache} +: ${DIST_PYTHON_DIR:=dist-python} + +# Install flashinfer packages +echo "Installing flashinfer-cubin from ${DIST_CUBIN_DIR}..." +pip install ${DIST_CUBIN_DIR}/*.whl + +echo "Installing flashinfer-jit-cache from ${DIST_JIT_CACHE_DIR}..." +pip install ${DIST_JIT_CACHE_DIR}/*.whl + +echo "Installing flashinfer-python from ${DIST_PYTHON_DIR}..." +pip install ${DIST_PYTHON_DIR}/*.tar.gz + +# Verify installation +echo "Verifying installation..." +python -m flashinfer show-config + +# Run test shard +echo "Running test shard ${TEST_SHARD}..." +export SKIP_INSTALL=1 +bash scripts/task_jit_run_tests_part${TEST_SHARD}.sh From a8417211f21ca224bbd43febf1c3f5c522b72699 Mon Sep 17 00:00:00 2001 From: Zihao Ye Date: Mon, 6 Oct 2025 01:38:59 -0400 Subject: [PATCH 19/41] upd --- .github/workflows/nightly-release.yml | 6 ------ scripts/task_test_nightly_build.sh | 4 ++++ 2 files changed, 4 insertions(+), 6 deletions(-) diff --git a/.github/workflows/nightly-release.yml b/.github/workflows/nightly-release.yml index 2f3cdd6c0d..8b1945bed8 100644 --- a/.github/workflows/nightly-release.yml +++ b/.github/workflows/nightly-release.yml @@ -293,18 +293,12 @@ jobs: echo "RAM: $(free -h | awk '/^Mem:/ {print $7 " available out of " $2}')" echo "Disk: $(df -h / | awk 'NR==2 {print $4 " available out of " $2}')" echo "Architecture: $(uname -m)" - nvidia-smi - name: Checkout code uses: actions/checkout@v4 with: submodules: true - - name: Set up Python - uses: actions/setup-python@v5 - with: - python-version: '3.10' - - name: Download flashinfer-python artifact uses: actions/download-artifact@v4 with: diff --git a/scripts/task_test_nightly_build.sh b/scripts/task_test_nightly_build.sh index 6428a082eb..3ba666545d 100755 --- a/scripts/task_test_nightly_build.sh +++ b/scripts/task_test_nightly_build.sh @@ -12,6 +12,10 @@ set -x : ${DIST_JIT_CACHE_DIR:=dist-jit-cache} : ${DIST_PYTHON_DIR:=dist-python} +# Display GPU information (running inside Docker container with GPU access) +echo "=== GPU Information ===" +nvidia-smi + # Install flashinfer packages echo "Installing flashinfer-cubin from ${DIST_CUBIN_DIR}..." pip install ${DIST_CUBIN_DIR}/*.whl From 05ce648b408ab17c2c31ad7b498b3c28d15b9369 Mon Sep 17 00:00:00 2001 From: Zihao Ye Date: Mon, 6 Oct 2025 01:53:02 -0400 Subject: [PATCH 20/41] upd --- flashinfer-cubin/build_backend.py | 14 ++------------ flashinfer-jit-cache/build_backend.py | 14 ++------------ flashinfer/jit/env.py | 17 +++++++++++++---- 3 files changed, 17 insertions(+), 28 deletions(-) diff --git a/flashinfer-cubin/build_backend.py b/flashinfer-cubin/build_backend.py index 62512ee09b..676396dce9 100644 --- a/flashinfer-cubin/build_backend.py +++ b/flashinfer-cubin/build_backend.py @@ -13,18 +13,8 @@ from build_utils import get_git_version - -# add flashinfer._build_meta, always override to ensure version is up-to-date -build_meta_file = Path(__file__).parent.parent / "flashinfer" / "_build_meta.py" -version_file = Path(__file__).parent.parent / "version.txt" -if version_file.exists(): - with open(version_file, "r") as f: - version = f.read().strip() -git_version = get_git_version(cwd=Path(__file__).parent.parent) -with open(build_meta_file, "w") as f: - f.write('"""Build metadata for flashinfer package."""\n') - f.write(f'__version__ = "{version}"\n') - f.write(f'__git_version__ = "{git_version}"\n') +# Skip version check when building flashinfer-cubin package +os.environ["FLASHINFER_DISABLE_VERSION_CHECK"] = "1" def _download_cubins(): diff --git a/flashinfer-jit-cache/build_backend.py b/flashinfer-jit-cache/build_backend.py index 10ada07e84..e0864337bb 100644 --- a/flashinfer-jit-cache/build_backend.py +++ b/flashinfer-jit-cache/build_backend.py @@ -26,18 +26,8 @@ from build_utils import get_git_version - -# add flashinfer._build_meta, always override to ensure version is up-to-date -build_meta_file = Path(__file__).parent.parent / "flashinfer" / "_build_meta.py" -version_file = Path(__file__).parent.parent / "version.txt" -if version_file.exists(): - with open(version_file, "r") as f: - version = f.read().strip() -git_version = get_git_version(cwd=Path(__file__).parent.parent) -with open(build_meta_file, "w") as f: - f.write('"""Build metadata for flashinfer package."""\n') - f.write(f'__version__ = "{version}"\n') - f.write(f'__git_version__ = "{git_version}"\n') +# Skip version check when building flashinfer-jit-cache package +os.environ["FLASHINFER_DISABLE_VERSION_CHECK"] = "1" def _create_build_metadata(): diff --git a/flashinfer/jit/env.py b/flashinfer/jit/env.py index 057ac97978..97330eba1c 100644 --- a/flashinfer/jit/env.py +++ b/flashinfer/jit/env.py @@ -44,11 +44,16 @@ def _get_cubin_dir(): import flashinfer_cubin flashinfer_cubin_version = flashinfer_cubin.__version__ - if flashinfer_version != flashinfer_cubin_version: + # Allow bypassing version check with environment variable + if ( + not os.getenv("FLASHINFER_DISABLE_VERSION_CHECK") + and flashinfer_version != flashinfer_cubin_version + ): raise RuntimeError( f"flashinfer-cubin version ({flashinfer_cubin_version}) does not match " f"flashinfer version ({flashinfer_version}). " - "Please install the same version of both packages." + "Please install the same version of both packages. " + "Set FLASHINFER_DISABLE_VERSION_CHECK=1 to bypass this check." ) return pathlib.Path(flashinfer_cubin.get_cubin_dir()) @@ -78,11 +83,15 @@ def _get_aot_dir(): flashinfer_jit_cache_version = flashinfer_jit_cache.__version__ # NOTE(Zihao): we don't use exact version match here because the version of flashinfer-jit-cache # contains the CUDA version suffix: e.g. 0.3.1+cu129. - if not flashinfer_jit_cache_version.startswith(flashinfer_version): + # Allow bypassing version check with environment variable + if not os.getenv( + "FLASHINFER_DISABLE_VERSION_CHECK" + ) and not flashinfer_jit_cache_version.startswith(flashinfer_version): raise RuntimeError( f"flashinfer-jit-cache version ({flashinfer_jit_cache_version}) does not match " f"flashinfer version ({flashinfer_version}). " - "Please install the same version of both packages." + "Please install the same version of both packages. " + "Set FLASHINFER_DISABLE_VERSION_CHECK=1 to bypass this check." ) return pathlib.Path(flashinfer_jit_cache.get_jit_cache_dir()) From d09ba32b0b0499e63edec02b29e7c481162bcddc Mon Sep 17 00:00:00 2001 From: Zihao Ye Date: Mon, 6 Oct 2025 02:03:01 -0400 Subject: [PATCH 21/41] upd --- flashinfer-cubin/build_backend.py | 17 ++++++++++++++--- flashinfer-jit-cache/build_backend.py | 18 ++++++++++++++---- 2 files changed, 28 insertions(+), 7 deletions(-) diff --git a/flashinfer-cubin/build_backend.py b/flashinfer-cubin/build_backend.py index 676396dce9..2c5b3760f1 100644 --- a/flashinfer-cubin/build_backend.py +++ b/flashinfer-cubin/build_backend.py @@ -6,7 +6,6 @@ import sys from pathlib import Path from setuptools import build_meta as _orig -from setuptools.build_meta import * # Add parent directory to path to import artifacts module sys.path.insert(0, str(Path(__file__).parent.parent)) @@ -93,5 +92,17 @@ def build_editable(wheel_directory, config_settings=None, metadata_directory=Non # Pass through all other hooks get_requires_for_build_wheel = _orig.get_requires_for_build_wheel get_requires_for_build_editable = _orig.get_requires_for_build_editable -prepare_metadata_for_build_wheel = _orig.prepare_metadata_for_build_wheel -prepare_metadata_for_build_editable = _orig.prepare_metadata_for_build_editable + + +def prepare_metadata_for_build_wheel(metadata_directory, config_settings=None): + """Prepare metadata for wheel build, creating build metadata first.""" + _create_build_metadata() + return _orig.prepare_metadata_for_build_wheel(metadata_directory, config_settings) + + +def prepare_metadata_for_build_editable(metadata_directory, config_settings=None): + """Prepare metadata for editable install, creating build metadata first.""" + _create_build_metadata() + return _orig.prepare_metadata_for_build_editable( + metadata_directory, config_settings + ) diff --git a/flashinfer-jit-cache/build_backend.py b/flashinfer-jit-cache/build_backend.py index e0864337bb..ecc95a09f5 100644 --- a/flashinfer-jit-cache/build_backend.py +++ b/flashinfer-jit-cache/build_backend.py @@ -37,7 +37,7 @@ def _create_build_metadata(): with open(version_file, "r") as f: version = f.read().strip() else: - version = "0.0.0" + version = "0.0.0+unknown" # Add dev suffix if specified dev_suffix = os.environ.get("FLASHINFER_DEV_RELEASE_SUFFIX", "") @@ -45,7 +45,7 @@ def _create_build_metadata(): version = f"{version}.dev{dev_suffix}" # Get git version - git_version = get_git_version() + git_version = get_git_version(cwd=Path(__file__).parent.parent) # Append CUDA version suffix if available cuda_suffix = os.environ.get("CUDA_VERSION_SUFFIX", "") @@ -63,7 +63,7 @@ def _create_build_metadata(): return version -def compile_jit_cache(output_dir: Path, verbose: bool = True): +def _compile_jit_cache(output_dir: Path, verbose: bool = True): """Compile AOT modules using flashinfer.aot functions directly.""" from flashinfer import aot @@ -91,7 +91,7 @@ def _build_aot_modules(): try: # Compile AOT modules - compile_jit_cache(aot_package_dir) + _compile_jit_cache(aot_package_dir) # Verify that some modules were actually compiled so_files = list(aot_package_dir.rglob("*.so")) @@ -198,6 +198,16 @@ def prepare_metadata_for_build_wheel(metadata_directory, config_settings=None): ) +def prepare_metadata_for_build_editable(metadata_directory, config_settings=None): + """Prepare metadata for editable install, creating build metadata first.""" + _create_build_metadata() + + with _MonkeyPatchBdistWheel(): + return _orig.prepare_metadata_for_build_editable( + metadata_directory, config_settings + ) + + # Export the required interface get_requires_for_build_wheel = _orig.get_requires_for_build_wheel get_requires_for_build_editable = getattr( From c12d5c422c19762924b4fd5163c461cf40674ca4 Mon Sep 17 00:00:00 2001 From: Zihao Ye Date: Mon, 6 Oct 2025 02:12:43 -0400 Subject: [PATCH 22/41] upd --- pyproject.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/pyproject.toml b/pyproject.toml index 93671270d9..402c6796e6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -39,6 +39,7 @@ skip = [ [tool.setuptools] include-package-data = false +py-modules = ["build_backend", "build_utils"] [tool.setuptools.dynamic] version = {attr = "flashinfer._build_meta.__version__"} From b18da8bca1c233193ab425aba602b791f1b77086 Mon Sep 17 00:00:00 2001 From: Zihao Ye Date: Mon, 6 Oct 2025 02:21:43 -0400 Subject: [PATCH 23/41] upd --- pyproject.toml | 2 ++ requirements.txt | 13 +++++++++++++ 2 files changed, 15 insertions(+) create mode 100644 requirements.txt diff --git a/pyproject.toml b/pyproject.toml index 402c6796e6..8359503923 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -43,10 +43,12 @@ py-modules = ["build_backend", "build_utils"] [tool.setuptools.dynamic] version = {attr = "flashinfer._build_meta.__version__"} +dependencies = {file = ["requirements.txt"]} [tool.setuptools.packages.find] where = ["."] include = ["flashinfer*"] +exclude = ["flashinfer-jit-cache*", "flashinfer-cubin*"] [tool.setuptools.package-dir] "flashinfer.data" = "." diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000000..a4e391c38d --- /dev/null +++ b/requirements.txt @@ -0,0 +1,13 @@ +apache-tvm-ffi==0.1.0b15 +click +einops +ninja +numpy +nvidia-cudnn-frontend>=1.13.0 +nvidia-cutlass-dsl>=4.2.1 +nvidia-ml-py +packaging>=24.2 +requests +tabulate +torch +tqdm From 7f6cbee87cde0b4b8d37b2a24a384251e67a5110 Mon Sep 17 00:00:00 2001 From: Zihao Ye Date: Mon, 6 Oct 2025 02:43:22 -0400 Subject: [PATCH 24/41] upd --- build_backend.py | 30 ++++++++++++++++++++++++------ 1 file changed, 24 insertions(+), 6 deletions(-) diff --git a/build_backend.py b/build_backend.py index b5568308cb..7d24fcb305 100644 --- a/build_backend.py +++ b/build_backend.py @@ -48,7 +48,7 @@ def write_if_different(path: Path, content: str) -> None: path.write_text(content) -def _create_data_dir(): +def _create_data_dir(use_symlinks=True): _data_dir.mkdir(parents=True, exist_ok=True) def ln(source: str, target: str) -> None: @@ -58,29 +58,47 @@ def ln(source: str, target: str) -> None: if dst.is_symlink(): dst.unlink() elif dst.is_dir(): - dst.rmdir() - dst.symlink_to(src, target_is_directory=True) + shutil.rmtree(dst) + else: + dst.unlink() + + if use_symlinks: + dst.symlink_to(src, target_is_directory=True) + else: + # For wheel/sdist, copy actual files instead of symlinks + if src.exists(): + shutil.copytree(src, dst, symlinks=False, dirs_exist_ok=True) ln("3rdparty/cutlass", "cutlass") ln("3rdparty/spdlog", "spdlog") ln("csrc", "csrc") ln("include", "include") + # Always ensure version.txt is present + version_file = _data_dir / "version.txt" + if not version_file.exists() or not use_symlinks: + shutil.copy(_root / "version.txt", version_file) + def _prepare_for_wheel(): - # Remove data directory + # For wheel, copy actual files instead of symlinks so they are included in the wheel if _data_dir.exists(): shutil.rmtree(_data_dir) + _create_data_dir(use_symlinks=False) def _prepare_for_editable(): - _create_data_dir() + # For editable install, use symlinks so changes are reflected immediately + if _data_dir.exists(): + shutil.rmtree(_data_dir) + _create_data_dir(use_symlinks=True) def _prepare_for_sdist(): - # Remove data directory + # For sdist, copy actual files instead of symlinks so they are included in the tarball if _data_dir.exists(): shutil.rmtree(_data_dir) + _create_data_dir(use_symlinks=False) def get_requires_for_build_wheel(config_settings=None): From 1db6e19926eed82855419674f334e13e5043519a Mon Sep 17 00:00:00 2001 From: Zihao Ye Date: Mon, 6 Oct 2025 02:52:58 -0400 Subject: [PATCH 25/41] upd --- flashinfer-cubin/build_backend.py | 22 ++++++---------------- flashinfer-jit-cache/build_backend.py | 11 +++++------ 2 files changed, 11 insertions(+), 22 deletions(-) diff --git a/flashinfer-cubin/build_backend.py b/flashinfer-cubin/build_backend.py index 2c5b3760f1..7d3dbb83eb 100644 --- a/flashinfer-cubin/build_backend.py +++ b/flashinfer-cubin/build_backend.py @@ -75,34 +75,24 @@ def _create_build_metadata(): return version +# Create build metadata as soon as this module is imported +_create_build_metadata() + + def build_wheel(wheel_directory, config_settings=None, metadata_directory=None): """Build a wheel, downloading cubins first.""" _download_cubins() - _create_build_metadata() return _orig.build_wheel(wheel_directory, config_settings, metadata_directory) def build_editable(wheel_directory, config_settings=None, metadata_directory=None): """Build an editable install, downloading cubins first.""" _download_cubins() - _create_build_metadata() return _orig.build_editable(wheel_directory, config_settings, metadata_directory) # Pass through all other hooks get_requires_for_build_wheel = _orig.get_requires_for_build_wheel get_requires_for_build_editable = _orig.get_requires_for_build_editable - - -def prepare_metadata_for_build_wheel(metadata_directory, config_settings=None): - """Prepare metadata for wheel build, creating build metadata first.""" - _create_build_metadata() - return _orig.prepare_metadata_for_build_wheel(metadata_directory, config_settings) - - -def prepare_metadata_for_build_editable(metadata_directory, config_settings=None): - """Prepare metadata for editable install, creating build metadata first.""" - _create_build_metadata() - return _orig.prepare_metadata_for_build_editable( - metadata_directory, config_settings - ) +prepare_metadata_for_build_wheel = _orig.prepare_metadata_for_build_wheel +prepare_metadata_for_build_editable = _orig.prepare_metadata_for_build_editable diff --git a/flashinfer-jit-cache/build_backend.py b/flashinfer-jit-cache/build_backend.py index ecc95a09f5..6cd8575f82 100644 --- a/flashinfer-jit-cache/build_backend.py +++ b/flashinfer-jit-cache/build_backend.py @@ -63,6 +63,10 @@ def _create_build_metadata(): return version +# Create build metadata as soon as this module is imported +_create_build_metadata() + + def _compile_jit_cache(output_dir: Path, verbose: bool = True): """Compile AOT modules using flashinfer.aot functions directly.""" from flashinfer import aot @@ -107,7 +111,6 @@ def _build_aot_modules(): def _prepare_build(): """Shared preparation logic for both wheel and editable builds.""" - _create_build_metadata() _build_aot_modules() @@ -190,8 +193,6 @@ def build_editable(wheel_directory, config_settings=None, metadata_directory=Non def prepare_metadata_for_build_wheel(metadata_directory, config_settings=None): """Prepare metadata with platform-specific wheel tags.""" - _create_build_metadata() - with _MonkeyPatchBdistWheel(): return _orig.prepare_metadata_for_build_wheel( metadata_directory, config_settings @@ -199,9 +200,7 @@ def prepare_metadata_for_build_wheel(metadata_directory, config_settings=None): def prepare_metadata_for_build_editable(metadata_directory, config_settings=None): - """Prepare metadata for editable install, creating build metadata first.""" - _create_build_metadata() - + """Prepare metadata for editable install.""" with _MonkeyPatchBdistWheel(): return _orig.prepare_metadata_for_build_editable( metadata_directory, config_settings From 23d2d6b9d77de06410edcec2c78f4b708606e6ed Mon Sep 17 00:00:00 2001 From: Zihao Ye Date: Mon, 6 Oct 2025 02:55:04 -0400 Subject: [PATCH 26/41] upd --- pyproject.toml | 3 +++ 1 file changed, 3 insertions(+) diff --git a/pyproject.toml b/pyproject.toml index 8359503923..f635e0cdc7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -56,6 +56,9 @@ exclude = ["flashinfer-jit-cache*", "flashinfer-cubin*"] "flashinfer.data.spdlog" = "3rdparty/spdlog" [tool.setuptools.package-data] +"flashinfer" = [ + "_build_meta.py" +] "flashinfer.data" = [ "csrc/**", "include/**", From 8cf6f6c43dcf1a5be2eced481dc7c5621dc10bee Mon Sep 17 00:00:00 2001 From: Zihao Ye Date: Mon, 6 Oct 2025 03:45:59 -0400 Subject: [PATCH 27/41] upd --- build_backend.py | 53 +++++++++++++++++++-------- docs/conf.py | 7 +--- flashinfer-cubin/build_backend.py | 10 +++++ flashinfer-jit-cache/build_backend.py | 10 +++++ pyproject.toml | 3 +- 5 files changed, 61 insertions(+), 22 deletions(-) diff --git a/build_backend.py b/build_backend.py index 7d24fcb305..9672dfb2e4 100644 --- a/build_backend.py +++ b/build_backend.py @@ -25,20 +25,48 @@ _data_dir = _root / "flashinfer" / "data" -def get_version(): - package_version = (_root / "version.txt").read_text().strip() +def _create_build_metadata(): + """Create build metadata file with version information.""" + version_file = _root / "version.txt" + if version_file.exists(): + with open(version_file, "r") as f: + version = f.read().strip() + else: + version = "0.0.0+unknown" + + # Add dev suffix if specified dev_suffix = os.environ.get("FLASHINFER_DEV_RELEASE_SUFFIX", "") if dev_suffix: - package_version = f"{package_version}.dev{dev_suffix}" - return package_version + version = f"{version}.dev{dev_suffix}" + # Get git version + git_version = get_git_version(cwd=_root) -# Create _build_meta.py at import time so setuptools can read the version -build_meta_file = _root / "flashinfer" / "_build_meta.py" -with open(build_meta_file, "w") as f: - f.write('"""Build metadata for flashinfer package."""\n') - f.write(f'__version__ = "{get_version()}"\n') - f.write(f'__git_version__ = "{get_git_version(cwd=_root)}"\n') + # Create build metadata in the source tree + package_dir = Path(__file__).parent / "flashinfer" + build_meta_file = package_dir / "_build_meta.py" + + # Check if we're in a git repository + git_dir = Path(__file__).parent / ".git" + in_git_repo = git_dir.exists() + + # If file exists and not in git repo (installing from sdist), keep existing file + if build_meta_file.exists() and not in_git_repo: + print("Build metadata file already exists (not in git repo), keeping it") + return version + + # In git repo (editable) or file doesn't exist, create/update it + with open(build_meta_file, "w") as f: + f.write('"""Build metadata for flashinfer package."""\n') + f.write(f'__version__ = "{version}"\n') + f.write(f'__git_version__ = "{git_version}"\n') + + print(f"Created build metadata file with version {version}") + return version + + +# Create build metadata as soon as this module is imported +_create_build_metadata() def write_if_different(path: Path, content: str) -> None: @@ -74,11 +102,6 @@ def ln(source: str, target: str) -> None: ln("csrc", "csrc") ln("include", "include") - # Always ensure version.txt is present - version_file = _data_dir / "version.txt" - if not version_file.exists() or not use_symlinks: - shutil.copy(_root / "version.txt", version_file) - def _prepare_for_wheel(): # For wheel, copy actual files instead of symlinks so they are included in the wheel diff --git a/docs/conf.py b/docs/conf.py index 2d9ccf5366..8b61d26c8c 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -1,4 +1,3 @@ -from pathlib import Path from typing import Any, List import flashinfer # noqa: F401 @@ -12,7 +11,6 @@ # -- Project information ----------------------------------------------------- # https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information -root = Path(__file__).parents[1].resolve() # FlashInfer is installed via pip before building docs autodoc_mock_imports = [ "torch", @@ -28,9 +26,8 @@ author = "FlashInfer Contributors" copyright = f"2023-2025, {author}" -package_version = (root / "version.txt").read_text().strip() -version = package_version -release = package_version +version = flashinfer.__version__ +release = flashinfer.__version__ # -- General configuration --------------------------------------------------- # https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration diff --git a/flashinfer-cubin/build_backend.py b/flashinfer-cubin/build_backend.py index 7d3dbb83eb..0021c1da67 100644 --- a/flashinfer-cubin/build_backend.py +++ b/flashinfer-cubin/build_backend.py @@ -66,6 +66,16 @@ def _create_build_metadata(): package_dir = Path(__file__).parent / "flashinfer_cubin" build_meta_file = package_dir / "_build_meta.py" + # Check if we're in a git repository + git_dir = Path(__file__).parent.parent / ".git" + in_git_repo = git_dir.exists() + + # If file exists and not in git repo (installing from sdist), keep existing file + if build_meta_file.exists() and not in_git_repo: + print("Build metadata file already exists (not in git repo), keeping it") + return version + + # In git repo (editable) or file doesn't exist, create/update it with open(build_meta_file, "w") as f: f.write('"""Build metadata for flashinfer-cubin package."""\n') f.write(f'__version__ = "{version}"\n') diff --git a/flashinfer-jit-cache/build_backend.py b/flashinfer-jit-cache/build_backend.py index 6cd8575f82..b9a1739070 100644 --- a/flashinfer-jit-cache/build_backend.py +++ b/flashinfer-jit-cache/build_backend.py @@ -54,6 +54,16 @@ def _create_build_metadata(): version = f"{version}+{cuda_suffix}" build_meta_file = Path(__file__).parent / "flashinfer_jit_cache" / "_build_meta.py" + # Check if we're in a git repository + git_dir = Path(__file__).parent.parent / ".git" + in_git_repo = git_dir.exists() + + # If file exists and not in git repo (installing from sdist), keep existing file + if build_meta_file.exists() and not in_git_repo: + print("Build metadata file already exists (not in git repo), keeping it") + return version + + # In git repo (editable) or file doesn't exist, create/update it with open(build_meta_file, "w") as f: f.write('"""Build metadata for flashinfer-jit-cache package."""\n') f.write(f'__version__ = "{version}"\n') diff --git a/pyproject.toml b/pyproject.toml index f635e0cdc7..b3beb2be76 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -61,8 +61,7 @@ exclude = ["flashinfer-jit-cache*", "flashinfer-cubin*"] ] "flashinfer.data" = [ "csrc/**", - "include/**", - "version.txt" + "include/**" ] "flashinfer.data.cutlass" = [ "include/**", From c60cedff3445178e3cbde6106ae6f070a56f1378 Mon Sep 17 00:00:00 2001 From: Zihao Ye Date: Mon, 6 Oct 2025 04:48:49 -0400 Subject: [PATCH 28/41] upd --- scripts/task_test_nightly_build.sh | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/scripts/task_test_nightly_build.sh b/scripts/task_test_nightly_build.sh index 3ba666545d..540c1912b4 100755 --- a/scripts/task_test_nightly_build.sh +++ b/scripts/task_test_nightly_build.sh @@ -28,7 +28,8 @@ pip install ${DIST_PYTHON_DIR}/*.tar.gz # Verify installation echo "Verifying installation..." -python -m flashinfer show-config +# Run from /tmp to avoid importing local flashinfer/ source directory +(cd /tmp && python -m flashinfer show-config) # Run test shard echo "Running test shard ${TEST_SHARD}..." From 69284ed8713ef156d9b1db9cf47bee6be0e3acee Mon Sep 17 00:00:00 2001 From: Zihao Ye Date: Mon, 6 Oct 2025 05:11:09 -0400 Subject: [PATCH 29/41] use import-mode=importlib --- pytest.ini | 1 + 1 file changed, 1 insertion(+) diff --git a/pytest.ini b/pytest.ini index 765de5e01a..91154582ca 100644 --- a/pytest.ini +++ b/pytest.ini @@ -1,2 +1,3 @@ [pytest] norecursedirs = test_helpers +addopts = --import-mode=importlib From f130e55b1f24e12fc6783f231abfbdaac6cb9c15 Mon Sep 17 00:00:00 2001 From: Zihao Ye Date: Mon, 6 Oct 2025 05:13:11 -0400 Subject: [PATCH 30/41] add unittests without jit --- flashinfer/aot.py | 3 --- flashinfer/jit/cpp_ext.py | 4 ++++ scripts/task_test_nightly_build.sh | 5 +++++ 3 files changed, 9 insertions(+), 3 deletions(-) diff --git a/flashinfer/aot.py b/flashinfer/aot.py index 81fc8e5eb8..5062e7c576 100644 --- a/flashinfer/aot.py +++ b/flashinfer/aot.py @@ -525,9 +525,6 @@ def gen_all_modules( # Add cuDNN FMHA module jit_specs.append(gen_cudnn_fmha_module()) - # NOTE(Zihao): just for debugging, remove later - jit_specs = [gen_spdlog_module()] - # dedup names = set() ret: List[JitSpec] = [] diff --git a/flashinfer/jit/cpp_ext.py b/flashinfer/jit/cpp_ext.py index fb0c40c00e..7d4021de19 100644 --- a/flashinfer/jit/cpp_ext.py +++ b/flashinfer/jit/cpp_ext.py @@ -263,6 +263,10 @@ def _get_num_workers() -> Optional[int]: def run_ninja(workdir: Path, ninja_file: Path, verbose: bool) -> None: + if os.environ.get("FLASHINFER_DISABLE_JIT"): + raise RuntimeError( + "JIT compilation is disabled via FLASHINFER_DISABLE_JIT environment variable" + ) workdir.mkdir(parents=True, exist_ok=True) command = [ "ninja", diff --git a/scripts/task_test_nightly_build.sh b/scripts/task_test_nightly_build.sh index 540c1912b4..63de18da85 100755 --- a/scripts/task_test_nightly_build.sh +++ b/scripts/task_test_nightly_build.sh @@ -23,6 +23,11 @@ pip install ${DIST_CUBIN_DIR}/*.whl echo "Installing flashinfer-jit-cache from ${DIST_JIT_CACHE_DIR}..." pip install ${DIST_JIT_CACHE_DIR}/*.whl +# Disable JIT to verify that jit-cache package contains all necessary +# precompiled modules for the test suite to pass without compilation +echo "Disabling JIT compilation to test with precompiled cache only..." +export FLASHINFER_DISABLE_JIT=1 + echo "Installing flashinfer-python from ${DIST_PYTHON_DIR}..." pip install ${DIST_PYTHON_DIR}/*.tar.gz From d3e7b6d7291e3ad66c2d7e2d94371856d743fd15 Mon Sep 17 00:00:00 2001 From: Zihao Ye Date: Mon, 6 Oct 2025 05:43:33 -0400 Subject: [PATCH 31/41] add backoff for download cubin files, and add number of retries --- flashinfer/artifacts.py | 47 +++++++++++++++++++--------------- flashinfer/jit/cubin_loader.py | 11 ++++---- 2 files changed, 32 insertions(+), 26 deletions(-) diff --git a/flashinfer/artifacts.py b/flashinfer/artifacts.py index 4d344d02bb..7a446bc997 100644 --- a/flashinfer/artifacts.py +++ b/flashinfer/artifacts.py @@ -26,7 +26,7 @@ from .jit.core import logger from .jit.cubin_loader import ( FLASHINFER_CUBINS_REPOSITORY, - get_cubin, + download_file, safe_urljoin, FLASHINFER_CUBIN_DIR, ) @@ -125,26 +125,31 @@ def download_artifacts() -> None: # HTTPS connections. session = requests.Session() - with temp_env_var("FLASHINFER_CUBIN_CHECKSUM_DISABLED", "1"): - cubin_files = list(get_cubin_file_list()) - num_threads = int(os.environ.get("FLASHINFER_CUBIN_DOWNLOAD_THREADS", "4")) - with tqdm_logging_redirect( - total=len(cubin_files), desc="Downloading cubins" - ) as pbar: - - def update_pbar_cb(_) -> None: - pbar.update(1) - - with ThreadPoolExecutor(num_threads) as pool: - futures = [] - for name in cubin_files: - fut = pool.submit(get_cubin, name, "", session) - fut.add_done_callback(update_pbar_cb) - futures.append(fut) - - results = [fut.result() for fut in as_completed(futures)] - - all_success = all(results) + cubin_files = list(get_cubin_file_list()) + num_threads = int(os.environ.get("FLASHINFER_CUBIN_DOWNLOAD_THREADS", "4")) + with tqdm_logging_redirect( + total=len(cubin_files), desc="Downloading cubins" + ) as pbar: + + def update_pbar_cb(_) -> None: + pbar.update(1) + + with ThreadPoolExecutor(num_threads) as pool: + futures = [] + for name in cubin_files: + source = safe_urljoin(FLASHINFER_CUBINS_REPOSITORY, name) + local_path = FLASHINFER_CUBIN_DIR / name + # Ensure parent directory exists + local_path.parent.mkdir(parents=True, exist_ok=True) + fut = pool.submit( + download_file, source, str(local_path), session=session + ) + fut.add_done_callback(update_pbar_cb) + futures.append(fut) + + results = [fut.result() for fut in as_completed(futures)] + + all_success = all(results) if not all_success: raise RuntimeError("Failed to download cubins") diff --git a/flashinfer/jit/cubin_loader.py b/flashinfer/jit/cubin_loader.py index 7d20281225..78615e86f9 100644 --- a/flashinfer/jit/cubin_loader.py +++ b/flashinfer/jit/cubin_loader.py @@ -44,7 +44,7 @@ def safe_urljoin(base, path): def download_file( source: str, local_path: str, - retries: int = 3, + retries: int = 4, delay: int = 5, timeout: int = 10, lock_timeout: int = 30, @@ -57,7 +57,7 @@ def download_file( - source (str): The URL or local file path of the file to download. - local_path (str): The local file path to save the downloaded/copied file. - retries (int): Number of retry attempts for URL downloads (default: 3). - - delay (int): Delay in seconds between retries (default: 5). + - delay (int): Initial delay in seconds for exponential backoff (default: 5). - timeout (int): Timeout for the HTTP request in seconds (default: 10). - lock_timeout (int): Timeout in seconds for the file lock (default: 30). @@ -87,7 +87,7 @@ def download_file( logger.error(f"Failed to copy local file: {e}") return False - # Handle URL downloads + # Handle URL downloads with exponential backoff for attempt in range(1, retries + 1): try: response = session.get(source, timeout=timeout) @@ -107,8 +107,9 @@ def download_file( ) if attempt < retries: - logger.info(f"Retrying in {delay} seconds...") - time.sleep(delay) + backoff_delay = delay * (2 ** (attempt - 1)) + logger.info(f"Retrying in {backoff_delay} seconds...") + time.sleep(backoff_delay) else: logger.error("Max retries reached. Download failed.") return False From 06ffe13bd5691ea3f28d4a3533dd8834d2647e5a Mon Sep 17 00:00:00 2001 From: Zihao Ye Date: Mon, 6 Oct 2025 21:22:23 -0400 Subject: [PATCH 32/41] bugfix: turned off verbose upd --- flashinfer/jit/core.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/flashinfer/jit/core.py b/flashinfer/jit/core.py index e01e5de460..3bc7465241 100644 --- a/flashinfer/jit/core.py +++ b/flashinfer/jit/core.py @@ -294,7 +294,7 @@ def gen_jit_spec( cflags += ["-O3"] # useful for ncu - if bool(os.environ.get("FLASHINFER_JIT_LINEINFO", "0")): + if os.environ.get("FLASHINFER_JIT_LINEINFO", "0") == "1": cuda_cflags += ["-lineinfo"] if extra_cflags is not None: From 99a6f17f769cf91ff5419c8f325ced88cea922bb Mon Sep 17 00:00:00 2001 From: Zihao Ye Date: Tue, 7 Oct 2025 06:29:32 -0400 Subject: [PATCH 33/41] upd --- flashinfer/jit/__init__.py | 1 + flashinfer/jit/core.py | 25 ++++ flashinfer/jit/cpp_ext.py | 4 - scripts/task_jit_run_tests_part1.sh | 10 +- scripts/task_jit_run_tests_part2.sh | 16 +-- scripts/task_jit_run_tests_part3.sh | 4 +- scripts/task_jit_run_tests_part4.sh | 9 +- scripts/task_jit_run_tests_part5.sh | 4 +- tests/attention/test_alibi.py | 42 ++++--- tests/attention/test_attention_sink.py | 39 +++--- tests/attention/test_batch_attention.py | 38 +++--- tests/attention/test_batch_decode_kernels.py | 54 +++++---- tests/attention/test_batch_invariant_fa2.py | 42 ++++--- tests/attention/test_batch_prefill_kernels.py | 36 +++--- tests/attention/test_deepseek_mla.py | 111 +++++++++--------- tests/attention/test_logits_cap.py | 41 ++++--- tests/attention/test_non_contiguous_decode.py | 42 ++++--- .../attention/test_non_contiguous_prefill.py | 28 +++-- tests/attention/test_shared_prefix_kernels.py | 42 ++++--- tests/attention/test_sliding_window.py | 42 ++++--- tests/attention/test_tensor_cores_decode.py | 42 ++++--- tests/conftest.py | 70 ++++++++++- tests/gemm/test_group_gemm.py | 12 +- tests/utils/test_activation.py | 20 ++-- tests/utils/test_block_sparse.py | 42 ++++--- tests/utils/test_pod_kernels.py | 78 ++++++------ 26 files changed, 529 insertions(+), 365 deletions(-) diff --git a/flashinfer/jit/__init__.py b/flashinfer/jit/__init__.py index 4b7bc8f856..314dee1eb3 100644 --- a/flashinfer/jit/__init__.py +++ b/flashinfer/jit/__init__.py @@ -61,6 +61,7 @@ from .core import build_jit_specs as build_jit_specs from .core import clear_cache_dir as clear_cache_dir from .core import gen_jit_spec as gen_jit_spec +from .core import MissingJITCacheError as MissingJITCacheError from .core import sm90a_nvcc_flags as sm90a_nvcc_flags from .core import sm100a_nvcc_flags as sm100a_nvcc_flags from .core import sm100f_nvcc_flags as sm100f_nvcc_flags diff --git a/flashinfer/jit/core.py b/flashinfer/jit/core.py index 3bc7465241..1f987b6036 100644 --- a/flashinfer/jit/core.py +++ b/flashinfer/jit/core.py @@ -18,6 +18,24 @@ os.makedirs(jit_env.FLASHINFER_CSRC_DIR, exist_ok=True) +class MissingJITCacheError(RuntimeError): + """ + Exception raised when JIT compilation is disabled and the JIT cache + does not contain the required precompiled module. + + This error indicates that a module needs to be added to the JIT cache + build configuration. + + Attributes: + spec: JitSpec of the missing module + message: Error message + """ + + def __init__(self, message: str, spec: Optional["JitSpec"] = None): + self.spec = spec + super().__init__(message) + + class FlashInferJITLogger(logging.Logger): def __init__(self, name): super().__init__(name) @@ -228,6 +246,13 @@ def is_ninja_generated(self) -> bool: return self.ninja_path.exists() def build(self, verbose: bool, need_lock: bool = True) -> None: + if os.environ.get("FLASHINFER_DISABLE_JIT"): + raise MissingJITCacheError( + "JIT compilation is disabled via FLASHINFER_DISABLE_JIT environment variable, " + "but the required module is not found in the JIT cache. " + "Please add the missing module to the JIT cache build configuration.", + spec=self, + ) lock = ( FileLock(self.lock_path, thread_local=False) if need_lock else nullcontext() ) diff --git a/flashinfer/jit/cpp_ext.py b/flashinfer/jit/cpp_ext.py index 7d4021de19..fb0c40c00e 100644 --- a/flashinfer/jit/cpp_ext.py +++ b/flashinfer/jit/cpp_ext.py @@ -263,10 +263,6 @@ def _get_num_workers() -> Optional[int]: def run_ninja(workdir: Path, ninja_file: Path, verbose: bool) -> None: - if os.environ.get("FLASHINFER_DISABLE_JIT"): - raise RuntimeError( - "JIT compilation is disabled via FLASHINFER_DISABLE_JIT environment variable" - ) workdir.mkdir(parents=True, exist_ok=True) command = [ "ninja", diff --git a/scripts/task_jit_run_tests_part1.sh b/scripts/task_jit_run_tests_part1.sh index ff7ac2663e..f564d09f36 100755 --- a/scripts/task_jit_run_tests_part1.sh +++ b/scripts/task_jit_run_tests_part1.sh @@ -10,9 +10,11 @@ if [ "$SKIP_INSTALL" = "0" ]; then pip install -e . -v fi +# Run all tests in a single pytest session for better coverage reporting +pytest -s \ + tests/attention/test_logits_cap.py \ + tests/attention/test_sliding_window.py \ + tests/attention/test_tensor_cores_decode.py \ + tests/attention/test_batch_decode_kernels.py # pytest -s tests/gemm/test_group_gemm.py -pytest -s tests/attention/test_logits_cap.py -pytest -s tests/attention/test_sliding_window.py -pytest -s tests/attention/test_tensor_cores_decode.py -pytest -s tests/attention/test_batch_decode_kernels.py # pytest -s tests/attention/test_alibi.py diff --git a/scripts/task_jit_run_tests_part2.sh b/scripts/task_jit_run_tests_part2.sh index 9b93133f5b..b1ca0e18bb 100755 --- a/scripts/task_jit_run_tests_part2.sh +++ b/scripts/task_jit_run_tests_part2.sh @@ -10,10 +10,12 @@ if [ "$SKIP_INSTALL" = "0" ]; then pip install -e . -v fi -pytest -s tests/utils/test_block_sparse.py -pytest -s tests/utils/test_jit_example.py -pytest -s tests/utils/test_jit_warmup.py -pytest -s tests/utils/test_norm.py -pytest -s tests/attention/test_rope.py -pytest -s tests/attention/test_mla_page.py -pytest -s tests/utils/test_quantization.py +# Run all tests in a single pytest session for better coverage reporting +pytest -s \ + tests/utils/test_block_sparse.py \ + tests/utils/test_jit_example.py \ + tests/utils/test_jit_warmup.py \ + tests/utils/test_norm.py \ + tests/attention/test_rope.py \ + tests/attention/test_mla_page.py \ + tests/utils/test_quantization.py diff --git a/scripts/task_jit_run_tests_part3.sh b/scripts/task_jit_run_tests_part3.sh index da82f5af1d..bba83614b6 100755 --- a/scripts/task_jit_run_tests_part3.sh +++ b/scripts/task_jit_run_tests_part3.sh @@ -10,4 +10,6 @@ if [ "$SKIP_INSTALL" = "0" ]; then pip install -e . -v fi -pytest -s tests/utils/test_sampling.py +# Run all tests in a single pytest session for better coverage reporting +pytest -s \ + tests/utils/test_sampling.py diff --git a/scripts/task_jit_run_tests_part4.sh b/scripts/task_jit_run_tests_part4.sh index 22e21c1d21..14377641d8 100755 --- a/scripts/task_jit_run_tests_part4.sh +++ b/scripts/task_jit_run_tests_part4.sh @@ -11,8 +11,11 @@ if [ "$SKIP_INSTALL" = "0" ]; then fi export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True # avoid memory fragmentation -pytest -s tests/attention/test_deepseek_mla.py -pytest -s tests/gemm/test_group_gemm.py -pytest -s tests/attention/test_batch_prefill_kernels.py + +# Run all tests in a single pytest session for better coverage reporting +pytest -s \ + tests/attention/test_deepseek_mla.py \ + tests/gemm/test_group_gemm.py \ + tests/attention/test_batch_prefill_kernels.py # NOTE(Zihao): need to fix tile size on KV dimension for head_dim=256 on small shared memory architecture (sm89) # pytest -s tests/attention/test_batch_attention.py diff --git a/scripts/task_jit_run_tests_part5.sh b/scripts/task_jit_run_tests_part5.sh index 5606673aef..81a920a1db 100755 --- a/scripts/task_jit_run_tests_part5.sh +++ b/scripts/task_jit_run_tests_part5.sh @@ -10,4 +10,6 @@ if [ "$SKIP_INSTALL" = "0" ]; then pip install -e . -v fi -pytest -s tests/utils/test_logits_processor.py +# Run all tests in a single pytest session for better coverage reporting +pytest -s \ + tests/utils/test_logits_processor.py diff --git a/tests/attention/test_alibi.py b/tests/attention/test_alibi.py index 417be942b7..9f891b3571 100644 --- a/tests/attention/test_alibi.py +++ b/tests/attention/test_alibi.py @@ -14,6 +14,8 @@ limitations under the License. """ +import importlib.util + import pytest import torch from tests.test_helpers.alibi_reference import alibi_attention @@ -27,26 +29,28 @@ @pytest.fixture(autouse=True, scope="module") def warmup_jit(): - flashinfer.jit.build_jit_specs( - gen_decode_attention_modules( - [torch.float16], # q_dtypes - [torch.float16], # kv_dtypes - [128, 256], # head_dims - [0, 2], # pos_encoding_modes - [False], # use_sliding_windows - [False], # use_logits_soft_caps + # Skip warmup if flashinfer_jit_cache package is installed + if importlib.util.find_spec("flashinfer_jit_cache") is None: + flashinfer.jit.build_jit_specs( + gen_decode_attention_modules( + [torch.float16], # q_dtypes + [torch.float16], # kv_dtypes + [128, 256], # head_dims + [0, 2], # pos_encoding_modes + [False], # use_sliding_windows + [False], # use_logits_soft_caps + ) + + gen_prefill_attention_modules( + [torch.float16], # q_dtypes + [torch.float16], # kv_dtypes + [128, 256], # head_dims + [0, 2], # pos_encoding_modes + [False], # use_sliding_windows + [False], # use_logits_soft_caps + [False], # use_fp16_qk_reductions + ), + verbose=False, ) - + gen_prefill_attention_modules( - [torch.float16], # q_dtypes - [torch.float16], # kv_dtypes - [128, 256], # head_dims - [0, 2], # pos_encoding_modes - [False], # use_sliding_windows - [False], # use_logits_soft_caps - [False], # use_fp16_qk_reductions - ), - verbose=False, - ) yield diff --git a/tests/attention/test_attention_sink.py b/tests/attention/test_attention_sink.py index ab3ddc6c4b..95ad831060 100644 --- a/tests/attention/test_attention_sink.py +++ b/tests/attention/test_attention_sink.py @@ -14,6 +14,7 @@ limitations under the License. """ +import importlib.util import math import pytest @@ -29,26 +30,28 @@ @pytest.fixture(autouse=True, scope="module") def warmup_jit(): - jit_specs = [] - for dtype in [torch.float16, torch.bfloat16]: - for backend in ["fa2", "fa3"]: - for use_swa in [True, False]: - for head_dim in [128]: - jit_specs.append( - gen_batch_prefill_attention_sink_module( - backend=backend, - dtype_q=dtype, - dtype_kv=dtype, - dtype_o=dtype, - dtype_idx=torch.int32, - head_dim_qk=head_dim, - head_dim_vo=head_dim, - pos_encoding_mode=0, - use_sliding_window=use_swa, + # Skip warmup if flashinfer_jit_cache package is installed + if importlib.util.find_spec("flashinfer_jit_cache") is None: + jit_specs = [] + for dtype in [torch.float16, torch.bfloat16]: + for backend in ["fa2", "fa3"]: + for use_swa in [True, False]: + for head_dim in [128]: + jit_specs.append( + gen_batch_prefill_attention_sink_module( + backend=backend, + dtype_q=dtype, + dtype_kv=dtype, + dtype_o=dtype, + dtype_idx=torch.int32, + head_dim_qk=head_dim, + head_dim_vo=head_dim, + pos_encoding_mode=0, + use_sliding_window=use_swa, + ) ) - ) - flashinfer.jit.build_jit_specs(jit_specs) + flashinfer.jit.build_jit_specs(jit_specs) yield diff --git a/tests/attention/test_batch_attention.py b/tests/attention/test_batch_attention.py index 315c11d04d..02d82b2da6 100644 --- a/tests/attention/test_batch_attention.py +++ b/tests/attention/test_batch_attention.py @@ -14,6 +14,8 @@ limitations under the License. """ +import importlib.util + import numpy as np import pytest import torch @@ -28,24 +30,26 @@ @pytest.fixture(autouse=True, scope="module") def warmup_jit(): - flashinfer.jit.build_jit_specs( - gen_persistent_batch_attention_modules( - [torch.float16, torch.bfloat16], # q_dtypes - [torch.float16, torch.bfloat16], # kv_dtypes - [64, 128, 256], # head_dims - [False, True], # use_logits_soft_cap + # Skip warmup if flashinfer_jit_cache package is installed + if importlib.util.find_spec("flashinfer_jit_cache") is None: + flashinfer.jit.build_jit_specs( + gen_persistent_batch_attention_modules( + [torch.float16, torch.bfloat16], # q_dtypes + [torch.float16, torch.bfloat16], # kv_dtypes + [64, 128, 256], # head_dims + [False, True], # use_logits_soft_cap + ) + + gen_prefill_attention_modules( + [torch.float16, torch.bfloat16], # q_dtypes + [torch.float16, torch.bfloat16], # kv_dtypes + [64, 128, 256], # head_dims + [0], # pos_encoding_modes + [False], # use_sliding_windows + [False, True], # use_logits_soft_caps + [False], # use_fp16_qk_reductions + ), + verbose=False, ) - + gen_prefill_attention_modules( - [torch.float16, torch.bfloat16], # q_dtypes - [torch.float16, torch.bfloat16], # kv_dtypes - [64, 128, 256], # head_dims - [0], # pos_encoding_modes - [False], # use_sliding_windows - [False, True], # use_logits_soft_caps - [False], # use_fp16_qk_reductions - ), - verbose=False, - ) # ------------------------- Configuration generation function ----------------------------- # diff --git a/tests/attention/test_batch_decode_kernels.py b/tests/attention/test_batch_decode_kernels.py index cd04c273c3..9504d4c156 100644 --- a/tests/attention/test_batch_decode_kernels.py +++ b/tests/attention/test_batch_decode_kernels.py @@ -14,6 +14,8 @@ limitations under the License. """ +import importlib.util + import pytest import torch from tests.test_helpers.jit_utils import ( @@ -26,32 +28,34 @@ @pytest.fixture(autouse=True, scope="module") def warmup_jit(): - flashinfer.jit.build_jit_specs( - gen_decode_attention_modules( - [torch.float16], # q_dtypes - [ - torch.float16, - torch.float8_e4m3fn, - ], # kv_dtypes - [128, 256], # head_dims - [0, 1], # pos_encoding_modes - [False], # use_sliding_windows - [False], # use_logits_soft_caps + # Skip warmup if flashinfer_jit_cache package is installed + if importlib.util.find_spec("flashinfer_jit_cache") is None: + flashinfer.jit.build_jit_specs( + gen_decode_attention_modules( + [torch.float16], # q_dtypes + [ + torch.float16, + torch.float8_e4m3fn, + ], # kv_dtypes + [128, 256], # head_dims + [0, 1], # pos_encoding_modes + [False], # use_sliding_windows + [False], # use_logits_soft_caps + ) + + gen_prefill_attention_modules( + [torch.float16], # q_dtypes + [ + torch.float16, + torch.float8_e4m3fn, + ], # kv_dtypes + [128, 256], # head_dims + [0, 1], # pos_encoding_modes + [False], # use_sliding_windows + [False], # use_logits_soft_caps + [False], # use_fp16_qk_reductions + ), + verbose=False, ) - + gen_prefill_attention_modules( - [torch.float16], # q_dtypes - [ - torch.float16, - torch.float8_e4m3fn, - ], # kv_dtypes - [128, 256], # head_dims - [0, 1], # pos_encoding_modes - [False], # use_sliding_windows - [False], # use_logits_soft_caps - [False], # use_fp16_qk_reductions - ), - verbose=False, - ) yield diff --git a/tests/attention/test_batch_invariant_fa2.py b/tests/attention/test_batch_invariant_fa2.py index 39e7102349..3a1ee518df 100644 --- a/tests/attention/test_batch_invariant_fa2.py +++ b/tests/attention/test_batch_invariant_fa2.py @@ -14,6 +14,8 @@ limitations under the License. """ +import importlib.util + import pytest import torch from tests.test_helpers.jit_utils import ( @@ -26,26 +28,28 @@ @pytest.fixture(autouse=True, scope="module") def warmup_jit(): - flashinfer.jit.build_jit_specs( - gen_decode_attention_modules( - [torch.float16], # q_dtypes - [torch.float16], # kv_dtypes - [64, 128, 256], # head_dims - [0, 1], # pos_encoding_modes - [False], # use_sliding_windows - [False], # use_logits_soft_caps + # Skip warmup if flashinfer_jit_cache package is installed + if importlib.util.find_spec("flashinfer_jit_cache") is None: + flashinfer.jit.build_jit_specs( + gen_decode_attention_modules( + [torch.float16], # q_dtypes + [torch.float16], # kv_dtypes + [64, 128, 256], # head_dims + [0, 1], # pos_encoding_modes + [False], # use_sliding_windows + [False], # use_logits_soft_caps + ) + + gen_prefill_attention_modules( + [torch.float16], # q_dtypes + [torch.float16], # kv_dtypes + [64, 128, 256], # head_dims + [0, 1], # pos_encoding_modes + [False], # use_sliding_windows + [False], # use_logits_soft_caps + [False], # use_fp16_qk_reductions + ), + verbose=False, ) - + gen_prefill_attention_modules( - [torch.float16], # q_dtypes - [torch.float16], # kv_dtypes - [64, 128, 256], # head_dims - [0, 1], # pos_encoding_modes - [False], # use_sliding_windows - [False], # use_logits_soft_caps - [False], # use_fp16_qk_reductions - ), - verbose=False, - ) yield diff --git a/tests/attention/test_batch_prefill_kernels.py b/tests/attention/test_batch_prefill_kernels.py index 8c89ee94d0..78924bc701 100644 --- a/tests/attention/test_batch_prefill_kernels.py +++ b/tests/attention/test_batch_prefill_kernels.py @@ -14,6 +14,8 @@ limitations under the License. """ +import importlib.util + import numpy import pytest import torch @@ -24,22 +26,24 @@ @pytest.fixture(autouse=True, scope="module") def warmup_jit(): - flashinfer.jit.build_jit_specs( - gen_prefill_attention_modules( - [torch.float16], # q_dtypes - [ - torch.float16, - torch.float8_e4m3fn, - torch.float8_e5m2, - ], # kv_dtypes - [128, 256], # head_dims - [0, 1], # pos_encoding_modes - [False], # use_sliding_windows - [False], # use_logits_soft_caps - [False], # use_fp16_qk_reductions - ), - verbose=False, - ) + # Skip warmup if flashinfer_jit_cache package is installed + if importlib.util.find_spec("flashinfer_jit_cache") is None: + flashinfer.jit.build_jit_specs( + gen_prefill_attention_modules( + [torch.float16], # q_dtypes + [ + torch.float16, + torch.float8_e4m3fn, + torch.float8_e5m2, + ], # kv_dtypes + [128, 256], # head_dims + [0, 1], # pos_encoding_modes + [False], # use_sliding_windows + [False], # use_logits_soft_caps + [False], # use_fp16_qk_reductions + ), + verbose=False, + ) yield diff --git a/tests/attention/test_deepseek_mla.py b/tests/attention/test_deepseek_mla.py index 85cafc2d86..5a431e7f13 100644 --- a/tests/attention/test_deepseek_mla.py +++ b/tests/attention/test_deepseek_mla.py @@ -14,6 +14,7 @@ limitations under the License. """ +import importlib.util import math import pytest @@ -37,64 +38,66 @@ @pytest.fixture(autouse=True, scope="module") def warmup_jit(): try: - modules = [] - for backend in ["fa2", "fa3"]: - if backend == "fa3" and not is_sm90a_supported(torch.device("cuda")): - continue - - modules.append( - gen_single_prefill_module( - backend, - torch.float16, - torch.float16, - torch.float16, - 192, - 128, - 0, - False, - False, - False, + # Skip warmup if flashinfer_jit_cache package is installed + if importlib.util.find_spec("flashinfer_jit_cache") is None: + modules = [] + for backend in ["fa2", "fa3"]: + if backend == "fa3" and not is_sm90a_supported(torch.device("cuda")): + continue + + modules.append( + gen_single_prefill_module( + backend, + torch.float16, + torch.float16, + torch.float16, + 192, + 128, + 0, + False, + False, + False, + ) ) - ) - - for backend in ["fa2", "fa3"]: - if backend == "fa3" and not is_sm90a_supported(torch.device("cuda")): - continue - - modules.append( - gen_batch_prefill_module( - backend, - torch.float16, - torch.float16, - torch.float16, - torch.int32, - 192, - 128, - 0, - False, - False, - False, + + for backend in ["fa2", "fa3"]: + if backend == "fa3" and not is_sm90a_supported(torch.device("cuda")): + continue + + modules.append( + gen_batch_prefill_module( + backend, + torch.float16, + torch.float16, + torch.float16, + torch.int32, + 192, + 128, + 0, + False, + False, + False, + ) ) - ) - - for backend in ["fa2", "fa3"]: - if backend == "fa3" and not is_sm90a_supported(torch.device("cuda")): - continue - - modules.append( - gen_batch_mla_module( - backend, - torch.float16, - torch.float16, - torch.float16, - torch.int32, - 512, - 64, - False, + + for backend in ["fa2", "fa3"]: + if backend == "fa3" and not is_sm90a_supported(torch.device("cuda")): + continue + + modules.append( + gen_batch_mla_module( + backend, + torch.float16, + torch.float16, + torch.float16, + torch.int32, + 512, + 64, + False, + ) ) - ) - build_jit_specs(modules, verbose=False) + build_jit_specs(modules, verbose=False) except Exception as e: # abort the test session if warmup fails pytest.exit(str(e)) diff --git a/tests/attention/test_logits_cap.py b/tests/attention/test_logits_cap.py index 5059f3764e..aef2bc806f 100644 --- a/tests/attention/test_logits_cap.py +++ b/tests/attention/test_logits_cap.py @@ -14,6 +14,7 @@ limitations under the License. """ +import importlib.util import math import pytest @@ -28,26 +29,28 @@ @pytest.fixture(autouse=True, scope="module") def warmup_jit(): - flashinfer.jit.build_jit_specs( - gen_decode_attention_modules( - [torch.float16], # q_dtypes - [torch.float16], # kv_dtypes - [128, 256], # head_dims - [0], # pos_encoding_modes - [False], # use_sliding_windows - [False, True], # use_logits_soft_caps + # Skip warmup if flashinfer_jit_cache package is installed + if importlib.util.find_spec("flashinfer_jit_cache") is None: + flashinfer.jit.build_jit_specs( + gen_decode_attention_modules( + [torch.float16], # q_dtypes + [torch.float16], # kv_dtypes + [128, 256], # head_dims + [0], # pos_encoding_modes + [False], # use_sliding_windows + [False, True], # use_logits_soft_caps + ) + + gen_prefill_attention_modules( + [torch.float16], # q_dtypes + [torch.float16], # kv_dtypes + [128, 256], # head_dims + [0], # pos_encoding_modes + [False], # use_sliding_windows + [False, True], # use_logits_soft_caps + [False], # use_fp16_qk_reductions + ), + verbose=False, ) - + gen_prefill_attention_modules( - [torch.float16], # q_dtypes - [torch.float16], # kv_dtypes - [128, 256], # head_dims - [0], # pos_encoding_modes - [False], # use_sliding_windows - [False, True], # use_logits_soft_caps - [False], # use_fp16_qk_reductions - ), - verbose=False, - ) yield diff --git a/tests/attention/test_non_contiguous_decode.py b/tests/attention/test_non_contiguous_decode.py index 198ecd3e9f..86c6d1a829 100644 --- a/tests/attention/test_non_contiguous_decode.py +++ b/tests/attention/test_non_contiguous_decode.py @@ -1,3 +1,5 @@ +import importlib.util + import pytest import torch from tests.test_helpers.jit_utils import ( @@ -10,26 +12,28 @@ @pytest.fixture(autouse=True, scope="module") def warmup_jit(): - flashinfer.jit.build_jit_specs( - gen_decode_attention_modules( - [torch.float16], # q_dtypes - [torch.float16], # kv_dtypes - [64, 128, 256], # head_dims - [0], # pos_encoding_modes - [False], # use_sliding_windows - [False], # use_logits_soft_caps + # Skip warmup if flashinfer_jit_cache package is installed + if importlib.util.find_spec("flashinfer_jit_cache") is None: + flashinfer.jit.build_jit_specs( + gen_decode_attention_modules( + [torch.float16], # q_dtypes + [torch.float16], # kv_dtypes + [64, 128, 256], # head_dims + [0], # pos_encoding_modes + [False], # use_sliding_windows + [False], # use_logits_soft_caps + ) + + gen_prefill_attention_modules( + [torch.float16], # q_dtypes + [torch.float16], # kv_dtypes + [64, 128, 256], # head_dims + [0], # pos_encoding_modes + [False], # use_sliding_windows + [False], # use_logits_soft_caps + [False], # use_fp16_qk_reductions + ), + verbose=False, ) - + gen_prefill_attention_modules( - [torch.float16], # q_dtypes - [torch.float16], # kv_dtypes - [64, 128, 256], # head_dims - [0], # pos_encoding_modes - [False], # use_sliding_windows - [False], # use_logits_soft_caps - [False], # use_fp16_qk_reductions - ), - verbose=False, - ) yield diff --git a/tests/attention/test_non_contiguous_prefill.py b/tests/attention/test_non_contiguous_prefill.py index 627ef3ca63..89ce1dfa50 100644 --- a/tests/attention/test_non_contiguous_prefill.py +++ b/tests/attention/test_non_contiguous_prefill.py @@ -14,6 +14,8 @@ limitations under the License. """ +import importlib.util + import pytest import torch from tests.test_helpers.jit_utils import gen_prefill_attention_modules @@ -23,18 +25,20 @@ @pytest.fixture(autouse=True, scope="module") def warmup_jit(): - flashinfer.jit.build_jit_specs( - gen_prefill_attention_modules( - [torch.float16], # q_dtypes - [torch.float16], # kv_dtypes - [64, 128, 256], # head_dims - [0], # pos_encoding_modes - [False], # use_sliding_windows - [False], # use_logits_soft_caps - [False], # use_fp16_qk_reductions - ), - verbose=False, - ) + # Skip warmup if flashinfer_jit_cache package is installed + if importlib.util.find_spec("flashinfer_jit_cache") is None: + flashinfer.jit.build_jit_specs( + gen_prefill_attention_modules( + [torch.float16], # q_dtypes + [torch.float16], # kv_dtypes + [64, 128, 256], # head_dims + [0], # pos_encoding_modes + [False], # use_sliding_windows + [False], # use_logits_soft_caps + [False], # use_fp16_qk_reductions + ), + verbose=False, + ) yield diff --git a/tests/attention/test_shared_prefix_kernels.py b/tests/attention/test_shared_prefix_kernels.py index fc25b8afc5..4f1acef410 100644 --- a/tests/attention/test_shared_prefix_kernels.py +++ b/tests/attention/test_shared_prefix_kernels.py @@ -14,6 +14,8 @@ limitations under the License. """ +import importlib.util + import pytest import torch from tests.test_helpers.jit_utils import ( @@ -26,26 +28,28 @@ @pytest.fixture(autouse=True, scope="module") def warmup_jit(): - flashinfer.jit.build_jit_specs( - gen_decode_attention_modules( - [torch.float16], # q_dtypes - [torch.float16], # kv_dtypes - [128, 256], # head_dims - [0], # pos_encoding_modes - [False], # use_sliding_windows - [False], # use_logits_soft_caps + # Skip warmup if flashinfer_jit_cache package is installed + if importlib.util.find_spec("flashinfer_jit_cache") is None: + flashinfer.jit.build_jit_specs( + gen_decode_attention_modules( + [torch.float16], # q_dtypes + [torch.float16], # kv_dtypes + [128, 256], # head_dims + [0], # pos_encoding_modes + [False], # use_sliding_windows + [False], # use_logits_soft_caps + ) + + gen_prefill_attention_modules( + [torch.float16], # q_dtypes + [torch.float16], # kv_dtypes + [128, 256], # head_dims + [0], # pos_encoding_modes + [False], # use_sliding_windows + [False], # use_logits_soft_caps + [False], # use_fp16_qk_reductions + ), + verbose=False, ) - + gen_prefill_attention_modules( - [torch.float16], # q_dtypes - [torch.float16], # kv_dtypes - [128, 256], # head_dims - [0], # pos_encoding_modes - [False], # use_sliding_windows - [False], # use_logits_soft_caps - [False], # use_fp16_qk_reductions - ), - verbose=False, - ) yield diff --git a/tests/attention/test_sliding_window.py b/tests/attention/test_sliding_window.py index fa22610578..d6b79f106e 100644 --- a/tests/attention/test_sliding_window.py +++ b/tests/attention/test_sliding_window.py @@ -14,6 +14,8 @@ limitations under the License. """ +import importlib.util + import pytest import torch from tests.test_helpers.jit_utils import ( @@ -26,26 +28,28 @@ @pytest.fixture(autouse=True, scope="module") def warmup_jit(): - flashinfer.jit.build_jit_specs( - gen_decode_attention_modules( - [torch.float16], # q_dtypes - [torch.float16], # kv_dtypes - [64, 128, 256], # head_dims - [0], # pos_encoding_modes - [False, True], # use_sliding_windows - [False], # use_logits_soft_caps + # Skip warmup if flashinfer_jit_cache package is installed + if importlib.util.find_spec("flashinfer_jit_cache") is None: + flashinfer.jit.build_jit_specs( + gen_decode_attention_modules( + [torch.float16], # q_dtypes + [torch.float16], # kv_dtypes + [64, 128, 256], # head_dims + [0], # pos_encoding_modes + [False, True], # use_sliding_windows + [False], # use_logits_soft_caps + ) + + gen_prefill_attention_modules( + [torch.float16], # q_dtypes + [torch.float16], # kv_dtypes + [64, 128, 256], # head_dims + [0], # pos_encoding_modes + [False, True], # use_sliding_windows + [False], # use_logits_soft_caps + [False], # use_fp16_qk_reductions + ), + verbose=False, ) - + gen_prefill_attention_modules( - [torch.float16], # q_dtypes - [torch.float16], # kv_dtypes - [64, 128, 256], # head_dims - [0], # pos_encoding_modes - [False, True], # use_sliding_windows - [False], # use_logits_soft_caps - [False], # use_fp16_qk_reductions - ), - verbose=False, - ) yield diff --git a/tests/attention/test_tensor_cores_decode.py b/tests/attention/test_tensor_cores_decode.py index c5bbd84d81..47bb80a1fe 100644 --- a/tests/attention/test_tensor_cores_decode.py +++ b/tests/attention/test_tensor_cores_decode.py @@ -14,6 +14,8 @@ limitations under the License. """ +import importlib.util + import pytest import torch from tests.test_helpers.jit_utils import ( @@ -26,26 +28,28 @@ @pytest.fixture(autouse=True, scope="module") def warmup_jit(): - flashinfer.jit.build_jit_specs( - gen_decode_attention_modules( - [torch.float16], # q_dtypes - [torch.float16], # kv_dtypes - [64, 128, 256], # head_dims - [0, 1], # pos_encoding_modes - [False], # use_sliding_windows - [False], # use_logits_soft_caps + # Skip warmup if flashinfer_jit_cache package is installed + if importlib.util.find_spec("flashinfer_jit_cache") is None: + flashinfer.jit.build_jit_specs( + gen_decode_attention_modules( + [torch.float16], # q_dtypes + [torch.float16], # kv_dtypes + [64, 128, 256], # head_dims + [0, 1], # pos_encoding_modes + [False], # use_sliding_windows + [False], # use_logits_soft_caps + ) + + gen_prefill_attention_modules( + [torch.float16], # q_dtypes + [torch.float16], # kv_dtypes + [64, 128, 256], # head_dims + [0, 1], # pos_encoding_modes + [False], # use_sliding_windows + [False], # use_logits_soft_caps + [False], # use_fp16_qk_reductions + ), + verbose=False, ) - + gen_prefill_attention_modules( - [torch.float16], # q_dtypes - [torch.float16], # kv_dtypes - [64, 128, 256], # head_dims - [0, 1], # pos_encoding_modes - [False], # use_sliding_windows - [False], # use_logits_soft_caps - [False], # use_fp16_qk_reductions - ), - verbose=False, - ) yield diff --git a/tests/conftest.py b/tests/conftest.py index 3e6694d1d8..64c4305053 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,6 +1,6 @@ import os import types -from typing import Any, Dict +from typing import Any, Dict, Set import pytest import torch @@ -8,6 +8,11 @@ from torch.torch_version import __version__ as torch_version import flashinfer +from flashinfer.jit import MissingJITCacheError + +# Global tracking for JIT cache coverage +# Store tuples of (test_name, module_name, spec_info) +_MISSING_JIT_CACHE_MODULES: Set[tuple] = set() TORCH_COMPILE_FNS = [ flashinfer.activation.silu_and_mul, @@ -129,11 +134,72 @@ def is_cuda_oom_error_str(e: str) -> bool: @pytest.hookimpl(tryfirst=True) def pytest_runtest_call(item): - # skip OOM error + # skip OOM error and missing JIT cache errors try: item.runtest() except (torch.cuda.OutOfMemoryError, RuntimeError) as e: if isinstance(e, torch.cuda.OutOfMemoryError) or is_cuda_oom_error_str(str(e)): pytest.skip("Skipping due to OOM") + elif isinstance(e, MissingJITCacheError): + # Record the test that was skipped due to missing JIT cache + test_name = item.nodeid + spec = e.spec + module_name = spec.name if spec else "unknown" + + # Create a dict with module info for reporting + spec_info = None + if spec: + spec_info = { + "name": spec.name, + "sources": [str(s) for s in spec.sources], + "needs_device_linking": spec.needs_device_linking, + "aot_path": str(spec.aot_path), + } + + _MISSING_JIT_CACHE_MODULES.add((test_name, module_name, str(spec_info))) + pytest.skip(f"Skipping due to missing JIT cache for module: {module_name}") else: raise + + +def pytest_terminal_summary(terminalreporter, exitstatus, config): + """Generate JIT cache coverage report at the end of test session""" + if not _MISSING_JIT_CACHE_MODULES: + return # No missing modules + + terminalreporter.section("flashinfer-jit-cache Package Coverage Report") + terminalreporter.write_line("") + terminalreporter.write_line( + "This report shows the coverage of the flashinfer-jit-cache package." + ) + terminalreporter.write_line( + "Tests are skipped when required modules are not found in the installed JIT cache." + ) + terminalreporter.write_line("") + terminalreporter.write_line( + f"âš ī¸ {len(_MISSING_JIT_CACHE_MODULES)} test(s) skipped due to missing JIT cache modules:" + ) + terminalreporter.write_line("") + + # Group by module name + module_to_tests = {} + for test_name, module_name, spec_info in _MISSING_JIT_CACHE_MODULES: + if module_name not in module_to_tests: + module_to_tests[module_name] = {"tests": [], "spec_info": spec_info} + module_to_tests[module_name]["tests"].append(test_name) + + for module_name in sorted(module_to_tests.keys()): + info = module_to_tests[module_name] + terminalreporter.write_line(f"Module: {module_name}") + terminalreporter.write_line(f" Spec: {info['spec_info']}") + terminalreporter.write_line(f" Affected tests ({len(info['tests'])}):") + for test in sorted(info["tests"]): + terminalreporter.write_line(f" - {test}") + terminalreporter.write_line("") + + terminalreporter.write_line( + "These tests require JIT compilation but FLASHINFER_DISABLE_JIT=1 was set." + ) + terminalreporter.write_line( + "To improve coverage, add the missing modules to the flashinfer-jit-cache build configuration." + ) diff --git a/tests/gemm/test_group_gemm.py b/tests/gemm/test_group_gemm.py index 7e0a02c610..6671cc4383 100644 --- a/tests/gemm/test_group_gemm.py +++ b/tests/gemm/test_group_gemm.py @@ -14,6 +14,8 @@ limitations under the License. """ +import importlib.util + import pytest import torch @@ -26,10 +28,12 @@ @pytest.fixture(autouse=True, scope="module") def warmup_jit(): - jit_specs = [flashinfer.gemm.gen_gemm_module()] - if is_sm90a_supported(torch.device("cuda:0")): - jit_specs.append(flashinfer.gemm.gen_gemm_sm90_module()) - flashinfer.jit.build_jit_specs(jit_specs, verbose=False) + # Skip warmup if flashinfer_jit_cache package is installed + if importlib.util.find_spec("flashinfer_jit_cache") is None: + jit_specs = [flashinfer.gemm.gen_gemm_module()] + if is_sm90a_supported(torch.device("cuda:0")): + jit_specs.append(flashinfer.gemm.gen_gemm_sm90_module()) + flashinfer.jit.build_jit_specs(jit_specs, verbose=False) yield diff --git a/tests/utils/test_activation.py b/tests/utils/test_activation.py index 3854d7f576..d42f959d13 100644 --- a/tests/utils/test_activation.py +++ b/tests/utils/test_activation.py @@ -14,6 +14,8 @@ limitations under the License. """ +import importlib.util + import pytest import torch @@ -23,14 +25,16 @@ @pytest.fixture(autouse=True, scope="module") def warmup_jit(): - flashinfer.jit.build_jit_specs( - [ - flashinfer.activation.gen_act_and_mul_module("silu"), - flashinfer.activation.gen_act_and_mul_module("gelu"), - flashinfer.activation.gen_act_and_mul_module("gelu_tanh"), - ], - verbose=False, - ) + # Skip warmup if flashinfer_jit_cache package is installed + if importlib.util.find_spec("flashinfer_jit_cache") is None: + flashinfer.jit.build_jit_specs( + [ + flashinfer.activation.gen_act_and_mul_module("silu"), + flashinfer.activation.gen_act_and_mul_module("gelu"), + flashinfer.activation.gen_act_and_mul_module("gelu_tanh"), + ], + verbose=False, + ) yield diff --git a/tests/utils/test_block_sparse.py b/tests/utils/test_block_sparse.py index 716e738db7..47e9204e63 100644 --- a/tests/utils/test_block_sparse.py +++ b/tests/utils/test_block_sparse.py @@ -14,6 +14,8 @@ limitations under the License. """ +import importlib.util + import numpy as np import pytest import scipy as sp @@ -28,26 +30,28 @@ @pytest.fixture(autouse=True, scope="module") def warmup_jit(): - flashinfer.jit.build_jit_specs( - gen_decode_attention_modules( - [torch.float16], # q_dtypes - [torch.float16], # kv_dtypes - [128, 256], # head_dims - [0], # pos_encoding_modes - [False], # use_sliding_windows - [False], # use_logits_soft_caps + # Skip warmup if flashinfer_jit_cache package is installed + if importlib.util.find_spec("flashinfer_jit_cache") is None: + flashinfer.jit.build_jit_specs( + gen_decode_attention_modules( + [torch.float16], # q_dtypes + [torch.float16], # kv_dtypes + [128, 256], # head_dims + [0], # pos_encoding_modes + [False], # use_sliding_windows + [False], # use_logits_soft_caps + ) + + gen_prefill_attention_modules( + [torch.float16], # q_dtypes + [torch.float16], # kv_dtypes + [128, 256], # head_dims + [0], # pos_encoding_modes + [False], # use_sliding_windows + [False], # use_logits_soft_caps + [False], # use_fp16_qk_reductions + ), + verbose=False, ) - + gen_prefill_attention_modules( - [torch.float16], # q_dtypes - [torch.float16], # kv_dtypes - [128, 256], # head_dims - [0], # pos_encoding_modes - [False], # use_sliding_windows - [False], # use_logits_soft_caps - [False], # use_fp16_qk_reductions - ), - verbose=False, - ) yield diff --git a/tests/utils/test_pod_kernels.py b/tests/utils/test_pod_kernels.py index 553e6c3a8c..5cfbd674b2 100644 --- a/tests/utils/test_pod_kernels.py +++ b/tests/utils/test_pod_kernels.py @@ -14,6 +14,8 @@ limitations under the License. """ +import importlib.util + import pytest import torch from tests.test_helpers.jit_utils import ( @@ -27,44 +29,46 @@ @pytest.fixture(autouse=True, scope="module") def warmup_jit(): - flashinfer.jit.build_jit_specs( - gen_decode_attention_modules( - [torch.float16], # q_dtypes - [torch.float16], # kv_dtypes - [128], # head_dims - [0], # pos_encoding_modes - [False], # use_sliding_windows - [False], # use_fp16_qk_reductions - ) - + gen_prefill_attention_modules( - [torch.float16], # q_dtypes - [ - torch.float16, - ], # kv_dtypes - [128], # head_dims - [0], # pos_encoding_modes - [False], # use_sliding_windows - [False], # use_logits_soft_cap - [False], # use_fp16_qk_reductions - ) - + [ - gen_pod_module( - torch.float16, # dtype_q - torch.float16, # dtype_kv - torch.float16, # dtype_o - 128, # head_dim - 0, # pos_encoding_mode_p - False, # use_sliding_window_p - False, # use_logits_soft_cap_p - False, # use_fp16_qk_reduction - torch.int32, # dtype_idx - 0, # pos_encoding_mode_d - False, # use_sliding_window_d - False, # use_logits_soft_cap_d + # Skip warmup if flashinfer_jit_cache package is installed + if importlib.util.find_spec("flashinfer_jit_cache") is None: + flashinfer.jit.build_jit_specs( + gen_decode_attention_modules( + [torch.float16], # q_dtypes + [torch.float16], # kv_dtypes + [128], # head_dims + [0], # pos_encoding_modes + [False], # use_sliding_windows + [False], # use_fp16_qk_reductions ) - ], - verbose=False, - ) + + gen_prefill_attention_modules( + [torch.float16], # q_dtypes + [ + torch.float16, + ], # kv_dtypes + [128], # head_dims + [0], # pos_encoding_modes + [False], # use_sliding_windows + [False], # use_logits_soft_cap + [False], # use_fp16_qk_reductions + ) + + [ + gen_pod_module( + torch.float16, # dtype_q + torch.float16, # dtype_kv + torch.float16, # dtype_o + 128, # head_dim + 0, # pos_encoding_mode_p + False, # use_sliding_window_p + False, # use_logits_soft_cap_p + False, # use_fp16_qk_reduction + torch.int32, # dtype_idx + 0, # pos_encoding_mode_d + False, # use_sliding_window_d + False, # use_logits_soft_cap_d + ) + ], + verbose=False, + ) yield From 1ce91320e2a550ca9812d6ae91f4869f1b8bd9ee Mon Sep 17 00:00:00 2001 From: Zihao Ye Date: Tue, 7 Oct 2025 12:33:43 -0400 Subject: [PATCH 34/41] upd --- .github/workflows/nightly-release.yml | 3 +- README.md | 49 +++++++++++++++++-------- docs/installation.rst | 51 +++++++++++++++++++-------- 3 files changed, 73 insertions(+), 30 deletions(-) diff --git a/.github/workflows/nightly-release.yml b/.github/workflows/nightly-release.yml index 8b1945bed8..c6f64ac053 100644 --- a/.github/workflows/nightly-release.yml +++ b/.github/workflows/nightly-release.yml @@ -344,10 +344,11 @@ jobs: with: path: artifacts/ - - name: Collect wheels + - name: Collect wheels and sdist run: | mkdir -p dist find artifacts/ -name "*.whl" -exec cp {} dist/ \; + find artifacts/ -name "*.tar.gz" -exec cp {} dist/ \; ls -lh dist/ - name: Clone wheel index diff --git a/README.md b/README.md index 61963117d7..a6da460f0b 100644 --- a/README.md +++ b/README.md @@ -42,50 +42,71 @@ FlashInfer supports PyTorch, TVM and C++ (header-only) APIs, and can be easily i Using our PyTorch API is the easiest way to get started: -### Install from PIP +### Install from PyPI -FlashInfer is available as a Python package for Linux on PyPI. You can install it with the following command: +FlashInfer is available as a Python package for Linux. Install the core package with: ```bash pip install flashinfer-python ``` +**Package Options:** +- **flashinfer-python**: Core package that compiles/downloads kernels on first use +- **flashinfer-cubin**: Pre-compiled kernel binaries for all supported GPU architectures +- **flashinfer-jit-cache**: Pre-built kernel cache for specific CUDA versions + +**For faster initialization and offline usage**, install the optional packages to have most kernels pre-compiled: +```bash +pip install flashinfer-python flashinfer-cubin flashinfer-jit-cache +``` + +This eliminates compilation and downloading overhead at runtime. + ### Install from Source -Alternatively, build FlashInfer from source: +Build the core package from source: ```bash git clone https://github.com/flashinfer-ai/flashinfer.git --recursive cd flashinfer python -m pip install -v . +``` -# for development & contribution, install in editable mode +**For development**, install in editable mode: +```bash python -m pip install --no-build-isolation -e . -v ``` -`flashinfer-python` is a source-only package and by default it will JIT compile/download kernels on-the-fly. -For fully offline deployment, we also provide two additional packages `flashinfer-jit-cache` and `flashinfer-cubin`, to pre-compile and download cubins ahead-of-time. - -#### flashinfer-cubin +**Build optional packages:** -To build `flashinfer-cubin` package from source: +`flashinfer-cubin`: ```bash cd flashinfer-cubin python -m build --no-isolation --wheel python -m pip install dist/*.whl ``` -#### flashinfer-jit-cache - -To build `flashinfer-jit-cache` package from source: +`flashinfer-jit-cache` (customize `FLASHINFER_CUDA_ARCH_LIST` for your target GPUs): ```bash -export FLASHINFER_CUDA_ARCH_LIST="7.5 8.0 8.9 10.0a 10.3a 12.0a" # user can shrink the list to specific architectures +export FLASHINFER_CUDA_ARCH_LIST="7.5 8.0 8.9 10.0a 10.3a 12.0a" cd flashinfer-jit-cache python -m build --no-isolation --wheel python -m pip install dist/*.whl ``` -For more details, refer to the [Install from Source documentation](https://docs.flashinfer.ai/installation.html#install-from-source). +For more details, see the [Install from Source documentation](https://docs.flashinfer.ai/installation.html#install-from-source). + +### Install Nightly Build + +Nightly builds are available for testing the latest features: + +```bash +# Core and cubin packages +pip install -U --pre flashinfer-python flashinfer-cubin --index-url https://flashinfer.ai/whl/nightly/ + +# JIT cache package (replace cu129 with your CUDA version: cu128, cu129, or cu130) +pip install -U --pre flashinfer-jit-cache --index-url https://flashinfer.ai/whl/nightly/cu129 +``` ### Trying it out diff --git a/docs/installation.rst b/docs/installation.rst index 3700dc0f3f..bc1669fa5b 100644 --- a/docs/installation.rst +++ b/docs/installation.rst @@ -24,23 +24,30 @@ The easiest way to install FlashInfer is via pip. Please note that the package c pip install flashinfer-python +Package Options +""""""""""""""" -.. _install-from-source: +FlashInfer provides three packages: -Install from Source -^^^^^^^^^^^^^^^^^^^ +- **flashinfer-python**: Core package that compiles/downloads kernels on first use +- **flashinfer-cubin**: Pre-compiled kernel binaries for all supported GPU architectures +- **flashinfer-jit-cache**: Pre-built kernel cache for specific CUDA versions -In certain cases, you may want to install FlashInfer from source code to try out the latest features in the main branch, or to customize the library for your specific needs. +**For faster initialization and offline usage**, install the optional packages to have most kernels pre-compiled: + +.. code-block:: bash + + pip install flashinfer-python flashinfer-cubin flashinfer-jit-cache -``flashinfer-python`` is a source-only package and by default it will JIT compile/download kernels on-the-fly. +This eliminates compilation and downloading overhead at runtime. -For fully offline deployment, we also provide two additional packages to pre-compile and download cubins ahead-of-time: -flashinfer-cubin - - Provides pre-compiled CUDA binaries for immediate use without runtime compilation. +.. _install-from-source: -flashinfer-jit-cache - - Pre-compiles kernels for specific CUDA architectures to enable fully offline deployment. +Install from Source +^^^^^^^^^^^^^^^^^^^ + +In certain cases, you may want to install FlashInfer from source code to try out the latest features in the main branch, or to customize the library for your specific needs. You can follow the steps below to install FlashInfer from source code: @@ -63,15 +70,15 @@ You can follow the steps below to install FlashInfer from source code: cd flashinfer python -m pip install -v . - For development & contribution, install in editable mode: + **For development**, install in editable mode: .. code-block:: bash python -m pip install --no-build-isolation -e . -v -4. (Optional) Build additional packages for offline deployment: +4. (Optional) Build optional packages: - To build ``flashinfer-cubin`` package from source: + Build ``flashinfer-cubin``: .. code-block:: bash @@ -79,11 +86,25 @@ You can follow the steps below to install FlashInfer from source code: python -m build --no-isolation --wheel python -m pip install dist/*.whl - To build ``flashinfer-jit-cache`` package from source: + Build ``flashinfer-jit-cache`` (customize ``FLASHINFER_CUDA_ARCH_LIST`` for your target GPUs): .. code-block:: bash - export FLASHINFER_CUDA_ARCH_LIST="7.5 8.0 8.9 10.0a 10.3a 12.0a" # user can shrink the list to specific architectures + export FLASHINFER_CUDA_ARCH_LIST="7.5 8.0 8.9 10.0a 10.3a 12.0a" cd flashinfer-jit-cache python -m build --no-isolation --wheel python -m pip install dist/*.whl + + +Install Nightly Build +^^^^^^^^^^^^^^^^^^^^^^ + +Nightly builds are available for testing the latest features: + +.. code-block:: bash + + # Core and cubin packages + pip install -U --pre flashinfer-python flashinfer-cubin --index-url https://flashinfer.ai/whl/nightly/ + + # JIT cache package (replace cu129 with your CUDA version: cu128, cu129, or cu130) + pip install -U --pre flashinfer-jit-cache --index-url https://flashinfer.ai/whl/nightly/cu129 From b717e2561c4fda95274e24b135cd2acb1177c262 Mon Sep 17 00:00:00 2001 From: Zihao Ye Date: Tue, 7 Oct 2025 13:01:22 -0400 Subject: [PATCH 35/41] upd --- .github/workflows/nightly-release.yml | 12 ++++++------ README.md | 3 ++- docs/installation.rst | 3 ++- 3 files changed, 10 insertions(+), 8 deletions(-) diff --git a/.github/workflows/nightly-release.yml b/.github/workflows/nightly-release.yml index c6f64ac053..45cd885357 100644 --- a/.github/workflows/nightly-release.yml +++ b/.github/workflows/nightly-release.yml @@ -66,20 +66,20 @@ jobs: python -m pip install --upgrade pip pip install build wheel - - name: Build flashinfer-python sdist + - name: Build flashinfer-python wheel and sdist env: FLASHINFER_DEV_RELEASE_SUFFIX: ${{ needs.setup.outputs.dev_suffix }} run: | echo "Building flashinfer-python with dev suffix: ${FLASHINFER_DEV_RELEASE_SUFFIX}" echo "Git commit: $(git rev-parse HEAD)" - python -m build --sdist + python -m build ls -lh dist/ - name: Upload flashinfer-python artifact uses: actions/upload-artifact@v4 with: - name: flashinfer-python-sdist - path: dist/*.tar.gz + name: flashinfer-python-dist + path: dist/* retention-days: 7 build-flashinfer-cubin: @@ -226,7 +226,7 @@ jobs: - name: Download flashinfer-python artifact uses: actions/download-artifact@v4 with: - name: flashinfer-python-sdist + name: flashinfer-python-dist path: dist-python/ - name: Upload flashinfer-python to release @@ -302,7 +302,7 @@ jobs: - name: Download flashinfer-python artifact uses: actions/download-artifact@v4 with: - name: flashinfer-python-sdist + name: flashinfer-python-dist path: dist-python/ - name: Download flashinfer-cubin artifact diff --git a/README.md b/README.md index a6da460f0b..bce1e9e78b 100644 --- a/README.md +++ b/README.md @@ -57,7 +57,8 @@ pip install flashinfer-python **For faster initialization and offline usage**, install the optional packages to have most kernels pre-compiled: ```bash -pip install flashinfer-python flashinfer-cubin flashinfer-jit-cache +pip install flashinfer-python flashinfer-cubin +pip install flashinfer-jit-cache --index-url https://flashinfer.ai/whl/ ``` This eliminates compilation and downloading overhead at runtime. diff --git a/docs/installation.rst b/docs/installation.rst index bc1669fa5b..35eee3ccd1 100644 --- a/docs/installation.rst +++ b/docs/installation.rst @@ -37,7 +37,8 @@ FlashInfer provides three packages: .. code-block:: bash - pip install flashinfer-python flashinfer-cubin flashinfer-jit-cache + pip install flashinfer-python flashinfer-cubin + pip install flashinfer-jit-cache --index-url https://flashinfer.ai/whl/ This eliminates compilation and downloading overhead at runtime. From 1e5787ac2cdcb5e455c78ede90f841811227c519 Mon Sep 17 00:00:00 2001 From: Zihao Ye Date: Tue, 7 Oct 2025 23:06:34 -0400 Subject: [PATCH 36/41] upd --- .github/workflows/nightly-release.yml | 47 ++++++- scripts/print_jit_cache_summary.py | 86 +++++++++++++ scripts/task_jit_run_tests_part1.sh | 11 +- scripts/task_jit_run_tests_part2.sh | 17 ++- scripts/task_jit_run_tests_part3.sh | 5 +- scripts/task_jit_run_tests_part4.sh | 9 +- scripts/task_jit_run_tests_part5.sh | 5 +- scripts/task_test_nightly_build.sh | 6 + tests/attention/test_alibi.py | 45 +++---- tests/attention/test_attention_sink.py | 43 +++---- tests/attention/test_batch_attention.py | 41 ++++--- tests/attention/test_batch_decode_kernels.py | 57 ++++----- tests/attention/test_batch_invariant_fa2.py | 45 +++---- tests/attention/test_batch_prefill_kernels.py | 39 +++--- tests/attention/test_deepseek_mla.py | 115 +++++++++--------- tests/attention/test_logits_cap.py | 45 +++---- tests/attention/test_non_contiguous_decode.py | 45 +++---- .../attention/test_non_contiguous_prefill.py | 31 ++--- tests/attention/test_shared_prefix_kernels.py | 45 +++---- tests/attention/test_sliding_window.py | 45 +++---- tests/attention/test_tensor_cores_decode.py | 45 +++---- tests/conftest.py | 25 ++++ tests/gemm/test_group_gemm.py | 15 +-- tests/utils/test_activation.py | 23 ++-- tests/utils/test_block_sparse.py | 45 +++---- tests/utils/test_pod_kernels.py | 81 ++++++------ 26 files changed, 595 insertions(+), 421 deletions(-) create mode 100644 scripts/print_jit_cache_summary.py diff --git a/.github/workflows/nightly-release.yml b/.github/workflows/nightly-release.yml index 45cd885357..dee60f1b10 100644 --- a/.github/workflows/nightly-release.yml +++ b/.github/workflows/nightly-release.yml @@ -330,7 +330,52 @@ jobs: CUDA_VISIBLE_DEVICES: 0 run: | DOCKER_IMAGE="flashinfer/flashinfer-ci-${{ steps.docker-tag.outputs.cuda_version }}:${{ steps.docker-tag.outputs.tag }}" - bash ci/bash.sh ${DOCKER_IMAGE} -e TEST_SHARD ${{ matrix.test-shard }} ./scripts/task_test_nightly_build.sh + bash ci/bash.sh ${DOCKER_IMAGE} \ + -e TEST_SHARD=${{ matrix.test-shard }} \ + -e FLASHINFER_JIT_CACHE_REPORT_FILE=/workspace/jit_cache_report_shard${{ matrix.test-shard }}_cuda${{ matrix.cuda }}.json \ + ./scripts/task_test_nightly_build.sh + + - name: Upload JIT cache report + if: always() + uses: actions/upload-artifact@v4 + with: + name: jit-cache-report-shard${{ matrix.test-shard }}-cuda${{ matrix.cuda }} + path: jit_cache_report_shard${{ matrix.test-shard }}_cuda${{ matrix.cuda }}.json + if-no-files-found: ignore + retention-days: 7 + + jit-cache-summary: + needs: test-nightly-build + if: always() + runs-on: ubuntu-latest + steps: + - name: Checkout code + uses: actions/checkout@v4 + + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: '3.10' + + - name: Download all JIT cache reports + uses: actions/download-artifact@v4 + with: + pattern: jit-cache-report-* + path: jit-reports/ + merge-multiple: true + + - name: Merge and print JIT cache summary + run: | + # Merge all report files into one + mkdir -p merged-reports + cat jit-reports/*.json > merged-reports/all_reports.json 2>/dev/null || echo "No JIT cache reports found" + + # Print summary + if [ -f merged-reports/all_reports.json ] && [ -s merged-reports/all_reports.json ]; then + python scripts/print_jit_cache_summary.py merged-reports/all_reports.json + else + echo "✅ No missing JIT cache modules - all tests passed!" + fi update-wheel-index: needs: [setup, create-release, test-nightly-build] diff --git a/scripts/print_jit_cache_summary.py b/scripts/print_jit_cache_summary.py new file mode 100644 index 0000000000..c719374d08 --- /dev/null +++ b/scripts/print_jit_cache_summary.py @@ -0,0 +1,86 @@ +#!/usr/bin/env python3 +""" +Print aggregated JIT cache coverage summary from multiple pytest runs. + +This script reads the JSON file written by pytest (via conftest.py) and +prints a unified summary of all missing JIT cache modules across all test runs. + +Usage: + python scripts/print_jit_cache_summary.py /path/to/aggregate.json +""" + +import json +import sys +from pathlib import Path +from typing import Dict, Set + + +def print_jit_cache_summary(aggregate_file: str): + """Read and print JIT cache coverage summary from aggregate file""" + aggregate_path = Path(aggregate_file) + + if not aggregate_path.exists(): + print("No JIT cache coverage data found.") + print(f"Expected file: {aggregate_file}") + return + + # Read all entries from the file + missing_modules: Set[tuple] = set() + with open(aggregate_file, "r") as f: + for line in f: + line = line.strip() + if not line: + continue + entry = json.loads(line) + missing_modules.add( + (entry["test_name"], entry["module_name"], entry["spec_info"]) + ) + + if not missing_modules: + print("✅ All tests passed - no missing JIT cache modules!") + return + + # Print summary + print("=" * 80) + print("flashinfer-jit-cache Package Coverage Report") + print("=" * 80) + print() + print("This report shows the coverage of the flashinfer-jit-cache package.") + print( + "Tests are skipped when required modules are not found in the installed JIT cache." + ) + print() + print( + f"âš ī¸ {len(missing_modules)} test(s) skipped due to missing JIT cache modules:" + ) + print() + + # Group by module name + module_to_tests: Dict[str, Dict] = {} + for test_name, module_name, spec_info in missing_modules: + if module_name not in module_to_tests: + module_to_tests[module_name] = {"tests": [], "spec_info": spec_info} + module_to_tests[module_name]["tests"].append(test_name) + + for module_name in sorted(module_to_tests.keys()): + info = module_to_tests[module_name] + print(f"Module: {module_name}") + print(f" Spec: {info['spec_info']}") + print(f" Affected tests ({len(info['tests'])}):") + for test in sorted(info["tests"]): + print(f" - {test}") + print() + + print("These tests require JIT compilation but FLASHINFER_DISABLE_JIT=1 was set.") + print( + "To improve coverage, add the missing modules to the flashinfer-jit-cache build configuration." + ) + print("=" * 80) + + +if __name__ == "__main__": + if len(sys.argv) != 2: + print("Usage: python scripts/print_jit_cache_summary.py ") + sys.exit(1) + + print_jit_cache_summary(sys.argv[1]) diff --git a/scripts/task_jit_run_tests_part1.sh b/scripts/task_jit_run_tests_part1.sh index f564d09f36..60deeaac7b 100755 --- a/scripts/task_jit_run_tests_part1.sh +++ b/scripts/task_jit_run_tests_part1.sh @@ -10,11 +10,10 @@ if [ "$SKIP_INSTALL" = "0" ]; then pip install -e . -v fi -# Run all tests in a single pytest session for better coverage reporting -pytest -s \ - tests/attention/test_logits_cap.py \ - tests/attention/test_sliding_window.py \ - tests/attention/test_tensor_cores_decode.py \ - tests/attention/test_batch_decode_kernels.py +# Run each test file separately to isolate CUDA memory issues +pytest -s tests/attention/test_logits_cap.py +pytest -s tests/attention/test_sliding_window.py +pytest -s tests/attention/test_tensor_cores_decode.py +pytest -s tests/attention/test_batch_decode_kernels.py # pytest -s tests/gemm/test_group_gemm.py # pytest -s tests/attention/test_alibi.py diff --git a/scripts/task_jit_run_tests_part2.sh b/scripts/task_jit_run_tests_part2.sh index b1ca0e18bb..b4bb6bf17c 100755 --- a/scripts/task_jit_run_tests_part2.sh +++ b/scripts/task_jit_run_tests_part2.sh @@ -10,12 +10,11 @@ if [ "$SKIP_INSTALL" = "0" ]; then pip install -e . -v fi -# Run all tests in a single pytest session for better coverage reporting -pytest -s \ - tests/utils/test_block_sparse.py \ - tests/utils/test_jit_example.py \ - tests/utils/test_jit_warmup.py \ - tests/utils/test_norm.py \ - tests/attention/test_rope.py \ - tests/attention/test_mla_page.py \ - tests/utils/test_quantization.py +# Run each test file separately to isolate CUDA memory issues +pytest -s tests/utils/test_block_sparse.py +pytest -s tests/utils/test_jit_example.py +pytest -s tests/utils/test_jit_warmup.py +pytest -s tests/utils/test_norm.py +pytest -s tests/attention/test_rope.py +pytest -s tests/attention/test_mla_page.py +pytest -s tests/utils/test_quantization.py diff --git a/scripts/task_jit_run_tests_part3.sh b/scripts/task_jit_run_tests_part3.sh index bba83614b6..cb59c7e84f 100755 --- a/scripts/task_jit_run_tests_part3.sh +++ b/scripts/task_jit_run_tests_part3.sh @@ -10,6 +10,5 @@ if [ "$SKIP_INSTALL" = "0" ]; then pip install -e . -v fi -# Run all tests in a single pytest session for better coverage reporting -pytest -s \ - tests/utils/test_sampling.py +# Run each test file separately to isolate CUDA memory issues +pytest -s tests/utils/test_sampling.py diff --git a/scripts/task_jit_run_tests_part4.sh b/scripts/task_jit_run_tests_part4.sh index 14377641d8..c771fa37a7 100755 --- a/scripts/task_jit_run_tests_part4.sh +++ b/scripts/task_jit_run_tests_part4.sh @@ -12,10 +12,9 @@ fi export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True # avoid memory fragmentation -# Run all tests in a single pytest session for better coverage reporting -pytest -s \ - tests/attention/test_deepseek_mla.py \ - tests/gemm/test_group_gemm.py \ - tests/attention/test_batch_prefill_kernels.py +# Run each test file separately to isolate CUDA memory issues +pytest -s tests/attention/test_deepseek_mla.py +pytest -s tests/gemm/test_group_gemm.py +pytest -s tests/attention/test_batch_prefill_kernels.py # NOTE(Zihao): need to fix tile size on KV dimension for head_dim=256 on small shared memory architecture (sm89) # pytest -s tests/attention/test_batch_attention.py diff --git a/scripts/task_jit_run_tests_part5.sh b/scripts/task_jit_run_tests_part5.sh index 81a920a1db..a4ada8334a 100755 --- a/scripts/task_jit_run_tests_part5.sh +++ b/scripts/task_jit_run_tests_part5.sh @@ -10,6 +10,5 @@ if [ "$SKIP_INSTALL" = "0" ]; then pip install -e . -v fi -# Run all tests in a single pytest session for better coverage reporting -pytest -s \ - tests/utils/test_logits_processor.py +# Run each test file separately to isolate CUDA memory issues +pytest -s tests/utils/test_logits_processor.py diff --git a/scripts/task_test_nightly_build.sh b/scripts/task_test_nightly_build.sh index 63de18da85..46f6b76d36 100755 --- a/scripts/task_test_nightly_build.sh +++ b/scripts/task_test_nightly_build.sh @@ -39,4 +39,10 @@ echo "Verifying installation..." # Run test shard echo "Running test shard ${TEST_SHARD}..." export SKIP_INSTALL=1 + +# Pass through JIT cache report file if set +if [ -n "${FLASHINFER_JIT_CACHE_REPORT_FILE}" ]; then + export FLASHINFER_JIT_CACHE_REPORT_FILE +fi + bash scripts/task_jit_run_tests_part${TEST_SHARD}.sh diff --git a/tests/attention/test_alibi.py b/tests/attention/test_alibi.py index 9f891b3571..0971af2765 100644 --- a/tests/attention/test_alibi.py +++ b/tests/attention/test_alibi.py @@ -27,30 +27,31 @@ import flashinfer -@pytest.fixture(autouse=True, scope="module") +@pytest.fixture( + autouse=importlib.util.find_spec("flashinfer_jit_cache") is None, + scope="module", +) def warmup_jit(): - # Skip warmup if flashinfer_jit_cache package is installed - if importlib.util.find_spec("flashinfer_jit_cache") is None: - flashinfer.jit.build_jit_specs( - gen_decode_attention_modules( - [torch.float16], # q_dtypes - [torch.float16], # kv_dtypes - [128, 256], # head_dims - [0, 2], # pos_encoding_modes - [False], # use_sliding_windows - [False], # use_logits_soft_caps - ) - + gen_prefill_attention_modules( - [torch.float16], # q_dtypes - [torch.float16], # kv_dtypes - [128, 256], # head_dims - [0, 2], # pos_encoding_modes - [False], # use_sliding_windows - [False], # use_logits_soft_caps - [False], # use_fp16_qk_reductions - ), - verbose=False, + flashinfer.jit.build_jit_specs( + gen_decode_attention_modules( + [torch.float16], # q_dtypes + [torch.float16], # kv_dtypes + [128, 256], # head_dims + [0, 2], # pos_encoding_modes + [False], # use_sliding_windows + [False], # use_logits_soft_caps ) + + gen_prefill_attention_modules( + [torch.float16], # q_dtypes + [torch.float16], # kv_dtypes + [128, 256], # head_dims + [0, 2], # pos_encoding_modes + [False], # use_sliding_windows + [False], # use_logits_soft_caps + [False], # use_fp16_qk_reductions + ), + verbose=False, + ) yield diff --git a/tests/attention/test_attention_sink.py b/tests/attention/test_attention_sink.py index 95ad831060..3807a7d10f 100644 --- a/tests/attention/test_attention_sink.py +++ b/tests/attention/test_attention_sink.py @@ -28,30 +28,31 @@ from flashinfer.utils import is_sm90a_supported -@pytest.fixture(autouse=True, scope="module") +@pytest.fixture( + autouse=importlib.util.find_spec("flashinfer_jit_cache") is None, + scope="module", +) def warmup_jit(): - # Skip warmup if flashinfer_jit_cache package is installed - if importlib.util.find_spec("flashinfer_jit_cache") is None: - jit_specs = [] - for dtype in [torch.float16, torch.bfloat16]: - for backend in ["fa2", "fa3"]: - for use_swa in [True, False]: - for head_dim in [128]: - jit_specs.append( - gen_batch_prefill_attention_sink_module( - backend=backend, - dtype_q=dtype, - dtype_kv=dtype, - dtype_o=dtype, - dtype_idx=torch.int32, - head_dim_qk=head_dim, - head_dim_vo=head_dim, - pos_encoding_mode=0, - use_sliding_window=use_swa, - ) + jit_specs = [] + for dtype in [torch.float16, torch.bfloat16]: + for backend in ["fa2", "fa3"]: + for use_swa in [True, False]: + for head_dim in [128]: + jit_specs.append( + gen_batch_prefill_attention_sink_module( + backend=backend, + dtype_q=dtype, + dtype_kv=dtype, + dtype_o=dtype, + dtype_idx=torch.int32, + head_dim_qk=head_dim, + head_dim_vo=head_dim, + pos_encoding_mode=0, + use_sliding_window=use_swa, ) + ) - flashinfer.jit.build_jit_specs(jit_specs) + flashinfer.jit.build_jit_specs(jit_specs) yield diff --git a/tests/attention/test_batch_attention.py b/tests/attention/test_batch_attention.py index 02d82b2da6..95330cd8e3 100644 --- a/tests/attention/test_batch_attention.py +++ b/tests/attention/test_batch_attention.py @@ -28,28 +28,29 @@ from flashinfer.utils import get_compute_capability -@pytest.fixture(autouse=True, scope="module") +@pytest.fixture( + autouse=importlib.util.find_spec("flashinfer_jit_cache") is None, + scope="module", +) def warmup_jit(): - # Skip warmup if flashinfer_jit_cache package is installed - if importlib.util.find_spec("flashinfer_jit_cache") is None: - flashinfer.jit.build_jit_specs( - gen_persistent_batch_attention_modules( - [torch.float16, torch.bfloat16], # q_dtypes - [torch.float16, torch.bfloat16], # kv_dtypes - [64, 128, 256], # head_dims - [False, True], # use_logits_soft_cap - ) - + gen_prefill_attention_modules( - [torch.float16, torch.bfloat16], # q_dtypes - [torch.float16, torch.bfloat16], # kv_dtypes - [64, 128, 256], # head_dims - [0], # pos_encoding_modes - [False], # use_sliding_windows - [False, True], # use_logits_soft_caps - [False], # use_fp16_qk_reductions - ), - verbose=False, + flashinfer.jit.build_jit_specs( + gen_persistent_batch_attention_modules( + [torch.float16, torch.bfloat16], # q_dtypes + [torch.float16, torch.bfloat16], # kv_dtypes + [64, 128, 256], # head_dims + [False, True], # use_logits_soft_cap ) + + gen_prefill_attention_modules( + [torch.float16, torch.bfloat16], # q_dtypes + [torch.float16, torch.bfloat16], # kv_dtypes + [64, 128, 256], # head_dims + [0], # pos_encoding_modes + [False], # use_sliding_windows + [False, True], # use_logits_soft_caps + [False], # use_fp16_qk_reductions + ), + verbose=False, + ) # ------------------------- Configuration generation function ----------------------------- # diff --git a/tests/attention/test_batch_decode_kernels.py b/tests/attention/test_batch_decode_kernels.py index 9504d4c156..20c578b102 100644 --- a/tests/attention/test_batch_decode_kernels.py +++ b/tests/attention/test_batch_decode_kernels.py @@ -26,36 +26,37 @@ import flashinfer -@pytest.fixture(autouse=True, scope="module") +@pytest.fixture( + autouse=importlib.util.find_spec("flashinfer_jit_cache") is None, + scope="module", +) def warmup_jit(): - # Skip warmup if flashinfer_jit_cache package is installed - if importlib.util.find_spec("flashinfer_jit_cache") is None: - flashinfer.jit.build_jit_specs( - gen_decode_attention_modules( - [torch.float16], # q_dtypes - [ - torch.float16, - torch.float8_e4m3fn, - ], # kv_dtypes - [128, 256], # head_dims - [0, 1], # pos_encoding_modes - [False], # use_sliding_windows - [False], # use_logits_soft_caps - ) - + gen_prefill_attention_modules( - [torch.float16], # q_dtypes - [ - torch.float16, - torch.float8_e4m3fn, - ], # kv_dtypes - [128, 256], # head_dims - [0, 1], # pos_encoding_modes - [False], # use_sliding_windows - [False], # use_logits_soft_caps - [False], # use_fp16_qk_reductions - ), - verbose=False, + flashinfer.jit.build_jit_specs( + gen_decode_attention_modules( + [torch.float16], # q_dtypes + [ + torch.float16, + torch.float8_e4m3fn, + ], # kv_dtypes + [128, 256], # head_dims + [0, 1], # pos_encoding_modes + [False], # use_sliding_windows + [False], # use_logits_soft_caps ) + + gen_prefill_attention_modules( + [torch.float16], # q_dtypes + [ + torch.float16, + torch.float8_e4m3fn, + ], # kv_dtypes + [128, 256], # head_dims + [0, 1], # pos_encoding_modes + [False], # use_sliding_windows + [False], # use_logits_soft_caps + [False], # use_fp16_qk_reductions + ), + verbose=False, + ) yield diff --git a/tests/attention/test_batch_invariant_fa2.py b/tests/attention/test_batch_invariant_fa2.py index 3a1ee518df..e235fc54cb 100644 --- a/tests/attention/test_batch_invariant_fa2.py +++ b/tests/attention/test_batch_invariant_fa2.py @@ -26,30 +26,31 @@ import flashinfer -@pytest.fixture(autouse=True, scope="module") +@pytest.fixture( + autouse=importlib.util.find_spec("flashinfer_jit_cache") is None, + scope="module", +) def warmup_jit(): - # Skip warmup if flashinfer_jit_cache package is installed - if importlib.util.find_spec("flashinfer_jit_cache") is None: - flashinfer.jit.build_jit_specs( - gen_decode_attention_modules( - [torch.float16], # q_dtypes - [torch.float16], # kv_dtypes - [64, 128, 256], # head_dims - [0, 1], # pos_encoding_modes - [False], # use_sliding_windows - [False], # use_logits_soft_caps - ) - + gen_prefill_attention_modules( - [torch.float16], # q_dtypes - [torch.float16], # kv_dtypes - [64, 128, 256], # head_dims - [0, 1], # pos_encoding_modes - [False], # use_sliding_windows - [False], # use_logits_soft_caps - [False], # use_fp16_qk_reductions - ), - verbose=False, + flashinfer.jit.build_jit_specs( + gen_decode_attention_modules( + [torch.float16], # q_dtypes + [torch.float16], # kv_dtypes + [64, 128, 256], # head_dims + [0, 1], # pos_encoding_modes + [False], # use_sliding_windows + [False], # use_logits_soft_caps ) + + gen_prefill_attention_modules( + [torch.float16], # q_dtypes + [torch.float16], # kv_dtypes + [64, 128, 256], # head_dims + [0, 1], # pos_encoding_modes + [False], # use_sliding_windows + [False], # use_logits_soft_caps + [False], # use_fp16_qk_reductions + ), + verbose=False, + ) yield diff --git a/tests/attention/test_batch_prefill_kernels.py b/tests/attention/test_batch_prefill_kernels.py index 78924bc701..7bce4a3fd1 100644 --- a/tests/attention/test_batch_prefill_kernels.py +++ b/tests/attention/test_batch_prefill_kernels.py @@ -24,26 +24,27 @@ import flashinfer -@pytest.fixture(autouse=True, scope="module") +@pytest.fixture( + autouse=importlib.util.find_spec("flashinfer_jit_cache") is None, + scope="module", +) def warmup_jit(): - # Skip warmup if flashinfer_jit_cache package is installed - if importlib.util.find_spec("flashinfer_jit_cache") is None: - flashinfer.jit.build_jit_specs( - gen_prefill_attention_modules( - [torch.float16], # q_dtypes - [ - torch.float16, - torch.float8_e4m3fn, - torch.float8_e5m2, - ], # kv_dtypes - [128, 256], # head_dims - [0, 1], # pos_encoding_modes - [False], # use_sliding_windows - [False], # use_logits_soft_caps - [False], # use_fp16_qk_reductions - ), - verbose=False, - ) + flashinfer.jit.build_jit_specs( + gen_prefill_attention_modules( + [torch.float16], # q_dtypes + [ + torch.float16, + torch.float8_e4m3fn, + torch.float8_e5m2, + ], # kv_dtypes + [128, 256], # head_dims + [0, 1], # pos_encoding_modes + [False], # use_sliding_windows + [False], # use_logits_soft_caps + [False], # use_fp16_qk_reductions + ), + verbose=False, + ) yield diff --git a/tests/attention/test_deepseek_mla.py b/tests/attention/test_deepseek_mla.py index 5a431e7f13..5712b8e804 100644 --- a/tests/attention/test_deepseek_mla.py +++ b/tests/attention/test_deepseek_mla.py @@ -35,69 +35,70 @@ ) -@pytest.fixture(autouse=True, scope="module") +@pytest.fixture( + autouse=importlib.util.find_spec("flashinfer_jit_cache") is None, + scope="module", +) def warmup_jit(): try: - # Skip warmup if flashinfer_jit_cache package is installed - if importlib.util.find_spec("flashinfer_jit_cache") is None: - modules = [] - for backend in ["fa2", "fa3"]: - if backend == "fa3" and not is_sm90a_supported(torch.device("cuda")): - continue - - modules.append( - gen_single_prefill_module( - backend, - torch.float16, - torch.float16, - torch.float16, - 192, - 128, - 0, - False, - False, - False, - ) + modules = [] + for backend in ["fa2", "fa3"]: + if backend == "fa3" and not is_sm90a_supported(torch.device("cuda")): + continue + + modules.append( + gen_single_prefill_module( + backend, + torch.float16, + torch.float16, + torch.float16, + 192, + 128, + 0, + False, + False, + False, ) - - for backend in ["fa2", "fa3"]: - if backend == "fa3" and not is_sm90a_supported(torch.device("cuda")): - continue - - modules.append( - gen_batch_prefill_module( - backend, - torch.float16, - torch.float16, - torch.float16, - torch.int32, - 192, - 128, - 0, - False, - False, - False, - ) + ) + + for backend in ["fa2", "fa3"]: + if backend == "fa3" and not is_sm90a_supported(torch.device("cuda")): + continue + + modules.append( + gen_batch_prefill_module( + backend, + torch.float16, + torch.float16, + torch.float16, + torch.int32, + 192, + 128, + 0, + False, + False, + False, ) - - for backend in ["fa2", "fa3"]: - if backend == "fa3" and not is_sm90a_supported(torch.device("cuda")): - continue - - modules.append( - gen_batch_mla_module( - backend, - torch.float16, - torch.float16, - torch.float16, - torch.int32, - 512, - 64, - False, - ) + ) + + for backend in ["fa2", "fa3"]: + if backend == "fa3" and not is_sm90a_supported(torch.device("cuda")): + continue + + modules.append( + gen_batch_mla_module( + backend, + torch.float16, + torch.float16, + torch.float16, + torch.int32, + 512, + 64, + False, ) + ) - build_jit_specs(modules, verbose=False) + build_jit_specs(modules, verbose=False) except Exception as e: # abort the test session if warmup fails pytest.exit(str(e)) diff --git a/tests/attention/test_logits_cap.py b/tests/attention/test_logits_cap.py index aef2bc806f..b220200854 100644 --- a/tests/attention/test_logits_cap.py +++ b/tests/attention/test_logits_cap.py @@ -27,30 +27,31 @@ import flashinfer -@pytest.fixture(autouse=True, scope="module") +@pytest.fixture( + autouse=importlib.util.find_spec("flashinfer_jit_cache") is None, + scope="module", +) def warmup_jit(): - # Skip warmup if flashinfer_jit_cache package is installed - if importlib.util.find_spec("flashinfer_jit_cache") is None: - flashinfer.jit.build_jit_specs( - gen_decode_attention_modules( - [torch.float16], # q_dtypes - [torch.float16], # kv_dtypes - [128, 256], # head_dims - [0], # pos_encoding_modes - [False], # use_sliding_windows - [False, True], # use_logits_soft_caps - ) - + gen_prefill_attention_modules( - [torch.float16], # q_dtypes - [torch.float16], # kv_dtypes - [128, 256], # head_dims - [0], # pos_encoding_modes - [False], # use_sliding_windows - [False, True], # use_logits_soft_caps - [False], # use_fp16_qk_reductions - ), - verbose=False, + flashinfer.jit.build_jit_specs( + gen_decode_attention_modules( + [torch.float16], # q_dtypes + [torch.float16], # kv_dtypes + [128, 256], # head_dims + [0], # pos_encoding_modes + [False], # use_sliding_windows + [False, True], # use_logits_soft_caps ) + + gen_prefill_attention_modules( + [torch.float16], # q_dtypes + [torch.float16], # kv_dtypes + [128, 256], # head_dims + [0], # pos_encoding_modes + [False], # use_sliding_windows + [False, True], # use_logits_soft_caps + [False], # use_fp16_qk_reductions + ), + verbose=False, + ) yield diff --git a/tests/attention/test_non_contiguous_decode.py b/tests/attention/test_non_contiguous_decode.py index 86c6d1a829..e9fa5c9bdc 100644 --- a/tests/attention/test_non_contiguous_decode.py +++ b/tests/attention/test_non_contiguous_decode.py @@ -10,30 +10,31 @@ import flashinfer -@pytest.fixture(autouse=True, scope="module") +@pytest.fixture( + autouse=importlib.util.find_spec("flashinfer_jit_cache") is None, + scope="module", +) def warmup_jit(): - # Skip warmup if flashinfer_jit_cache package is installed - if importlib.util.find_spec("flashinfer_jit_cache") is None: - flashinfer.jit.build_jit_specs( - gen_decode_attention_modules( - [torch.float16], # q_dtypes - [torch.float16], # kv_dtypes - [64, 128, 256], # head_dims - [0], # pos_encoding_modes - [False], # use_sliding_windows - [False], # use_logits_soft_caps - ) - + gen_prefill_attention_modules( - [torch.float16], # q_dtypes - [torch.float16], # kv_dtypes - [64, 128, 256], # head_dims - [0], # pos_encoding_modes - [False], # use_sliding_windows - [False], # use_logits_soft_caps - [False], # use_fp16_qk_reductions - ), - verbose=False, + flashinfer.jit.build_jit_specs( + gen_decode_attention_modules( + [torch.float16], # q_dtypes + [torch.float16], # kv_dtypes + [64, 128, 256], # head_dims + [0], # pos_encoding_modes + [False], # use_sliding_windows + [False], # use_logits_soft_caps ) + + gen_prefill_attention_modules( + [torch.float16], # q_dtypes + [torch.float16], # kv_dtypes + [64, 128, 256], # head_dims + [0], # pos_encoding_modes + [False], # use_sliding_windows + [False], # use_logits_soft_caps + [False], # use_fp16_qk_reductions + ), + verbose=False, + ) yield diff --git a/tests/attention/test_non_contiguous_prefill.py b/tests/attention/test_non_contiguous_prefill.py index 89ce1dfa50..513488707e 100644 --- a/tests/attention/test_non_contiguous_prefill.py +++ b/tests/attention/test_non_contiguous_prefill.py @@ -23,22 +23,23 @@ import flashinfer -@pytest.fixture(autouse=True, scope="module") +@pytest.fixture( + autouse=importlib.util.find_spec("flashinfer_jit_cache") is None, + scope="module", +) def warmup_jit(): - # Skip warmup if flashinfer_jit_cache package is installed - if importlib.util.find_spec("flashinfer_jit_cache") is None: - flashinfer.jit.build_jit_specs( - gen_prefill_attention_modules( - [torch.float16], # q_dtypes - [torch.float16], # kv_dtypes - [64, 128, 256], # head_dims - [0], # pos_encoding_modes - [False], # use_sliding_windows - [False], # use_logits_soft_caps - [False], # use_fp16_qk_reductions - ), - verbose=False, - ) + flashinfer.jit.build_jit_specs( + gen_prefill_attention_modules( + [torch.float16], # q_dtypes + [torch.float16], # kv_dtypes + [64, 128, 256], # head_dims + [0], # pos_encoding_modes + [False], # use_sliding_windows + [False], # use_logits_soft_caps + [False], # use_fp16_qk_reductions + ), + verbose=False, + ) yield diff --git a/tests/attention/test_shared_prefix_kernels.py b/tests/attention/test_shared_prefix_kernels.py index 4f1acef410..8cdf820b10 100644 --- a/tests/attention/test_shared_prefix_kernels.py +++ b/tests/attention/test_shared_prefix_kernels.py @@ -26,30 +26,31 @@ import flashinfer -@pytest.fixture(autouse=True, scope="module") +@pytest.fixture( + autouse=importlib.util.find_spec("flashinfer_jit_cache") is None, + scope="module", +) def warmup_jit(): - # Skip warmup if flashinfer_jit_cache package is installed - if importlib.util.find_spec("flashinfer_jit_cache") is None: - flashinfer.jit.build_jit_specs( - gen_decode_attention_modules( - [torch.float16], # q_dtypes - [torch.float16], # kv_dtypes - [128, 256], # head_dims - [0], # pos_encoding_modes - [False], # use_sliding_windows - [False], # use_logits_soft_caps - ) - + gen_prefill_attention_modules( - [torch.float16], # q_dtypes - [torch.float16], # kv_dtypes - [128, 256], # head_dims - [0], # pos_encoding_modes - [False], # use_sliding_windows - [False], # use_logits_soft_caps - [False], # use_fp16_qk_reductions - ), - verbose=False, + flashinfer.jit.build_jit_specs( + gen_decode_attention_modules( + [torch.float16], # q_dtypes + [torch.float16], # kv_dtypes + [128, 256], # head_dims + [0], # pos_encoding_modes + [False], # use_sliding_windows + [False], # use_logits_soft_caps ) + + gen_prefill_attention_modules( + [torch.float16], # q_dtypes + [torch.float16], # kv_dtypes + [128, 256], # head_dims + [0], # pos_encoding_modes + [False], # use_sliding_windows + [False], # use_logits_soft_caps + [False], # use_fp16_qk_reductions + ), + verbose=False, + ) yield diff --git a/tests/attention/test_sliding_window.py b/tests/attention/test_sliding_window.py index d6b79f106e..93d21d00a0 100644 --- a/tests/attention/test_sliding_window.py +++ b/tests/attention/test_sliding_window.py @@ -26,30 +26,31 @@ import flashinfer -@pytest.fixture(autouse=True, scope="module") +@pytest.fixture( + autouse=importlib.util.find_spec("flashinfer_jit_cache") is None, + scope="module", +) def warmup_jit(): - # Skip warmup if flashinfer_jit_cache package is installed - if importlib.util.find_spec("flashinfer_jit_cache") is None: - flashinfer.jit.build_jit_specs( - gen_decode_attention_modules( - [torch.float16], # q_dtypes - [torch.float16], # kv_dtypes - [64, 128, 256], # head_dims - [0], # pos_encoding_modes - [False, True], # use_sliding_windows - [False], # use_logits_soft_caps - ) - + gen_prefill_attention_modules( - [torch.float16], # q_dtypes - [torch.float16], # kv_dtypes - [64, 128, 256], # head_dims - [0], # pos_encoding_modes - [False, True], # use_sliding_windows - [False], # use_logits_soft_caps - [False], # use_fp16_qk_reductions - ), - verbose=False, + flashinfer.jit.build_jit_specs( + gen_decode_attention_modules( + [torch.float16], # q_dtypes + [torch.float16], # kv_dtypes + [64, 128, 256], # head_dims + [0], # pos_encoding_modes + [False, True], # use_sliding_windows + [False], # use_logits_soft_caps ) + + gen_prefill_attention_modules( + [torch.float16], # q_dtypes + [torch.float16], # kv_dtypes + [64, 128, 256], # head_dims + [0], # pos_encoding_modes + [False, True], # use_sliding_windows + [False], # use_logits_soft_caps + [False], # use_fp16_qk_reductions + ), + verbose=False, + ) yield diff --git a/tests/attention/test_tensor_cores_decode.py b/tests/attention/test_tensor_cores_decode.py index 47bb80a1fe..7a00fd592c 100644 --- a/tests/attention/test_tensor_cores_decode.py +++ b/tests/attention/test_tensor_cores_decode.py @@ -26,30 +26,31 @@ import flashinfer -@pytest.fixture(autouse=True, scope="module") +@pytest.fixture( + autouse=importlib.util.find_spec("flashinfer_jit_cache") is None, + scope="module", +) def warmup_jit(): - # Skip warmup if flashinfer_jit_cache package is installed - if importlib.util.find_spec("flashinfer_jit_cache") is None: - flashinfer.jit.build_jit_specs( - gen_decode_attention_modules( - [torch.float16], # q_dtypes - [torch.float16], # kv_dtypes - [64, 128, 256], # head_dims - [0, 1], # pos_encoding_modes - [False], # use_sliding_windows - [False], # use_logits_soft_caps - ) - + gen_prefill_attention_modules( - [torch.float16], # q_dtypes - [torch.float16], # kv_dtypes - [64, 128, 256], # head_dims - [0, 1], # pos_encoding_modes - [False], # use_sliding_windows - [False], # use_logits_soft_caps - [False], # use_fp16_qk_reductions - ), - verbose=False, + flashinfer.jit.build_jit_specs( + gen_decode_attention_modules( + [torch.float16], # q_dtypes + [torch.float16], # kv_dtypes + [64, 128, 256], # head_dims + [0, 1], # pos_encoding_modes + [False], # use_sliding_windows + [False], # use_logits_soft_caps ) + + gen_prefill_attention_modules( + [torch.float16], # q_dtypes + [torch.float16], # kv_dtypes + [64, 128, 256], # head_dims + [0, 1], # pos_encoding_modes + [False], # use_sliding_windows + [False], # use_logits_soft_caps + [False], # use_fp16_qk_reductions + ), + verbose=False, + ) yield diff --git a/tests/conftest.py b/tests/conftest.py index 64c4305053..dc81dc0db2 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,5 +1,7 @@ +import json import os import types +from pathlib import Path from typing import Any, Dict, Set import pytest @@ -14,6 +16,9 @@ # Store tuples of (test_name, module_name, spec_info) _MISSING_JIT_CACHE_MODULES: Set[tuple] = set() +# File path for aggregating JIT cache info across multiple pytest runs +_JIT_CACHE_REPORT_FILE = os.environ.get("FLASHINFER_JIT_CACHE_REPORT_FILE", None) + TORCH_COMPILE_FNS = [ flashinfer.activation.silu_and_mul, flashinfer.activation.gelu_and_mul, @@ -167,6 +172,26 @@ def pytest_terminal_summary(terminalreporter, exitstatus, config): if not _MISSING_JIT_CACHE_MODULES: return # No missing modules + # If report file is specified, write to file for later aggregation + # Otherwise, print summary directly + if _JIT_CACHE_REPORT_FILE: + from filelock import FileLock + + # Convert set to list for JSON serialization + data = [ + {"test_name": test_name, "module_name": module_name, "spec_info": spec_info} + for test_name, module_name, spec_info in _MISSING_JIT_CACHE_MODULES + ] + + # Use file locking to handle concurrent writes from multiple pytest processes + Path(_JIT_CACHE_REPORT_FILE).parent.mkdir(parents=True, exist_ok=True) + lock_file = _JIT_CACHE_REPORT_FILE + ".lock" + with FileLock(lock_file), open(_JIT_CACHE_REPORT_FILE, "a") as f: + for entry in data: + f.write(json.dumps(entry) + "\n") + return + + # Single pytest run - print summary directly terminalreporter.section("flashinfer-jit-cache Package Coverage Report") terminalreporter.write_line("") terminalreporter.write_line( diff --git a/tests/gemm/test_group_gemm.py b/tests/gemm/test_group_gemm.py index 6671cc4383..ca65a33e84 100644 --- a/tests/gemm/test_group_gemm.py +++ b/tests/gemm/test_group_gemm.py @@ -26,14 +26,15 @@ CUDA_DEVICES = ["cuda:0"] -@pytest.fixture(autouse=True, scope="module") +@pytest.fixture( + autouse=importlib.util.find_spec("flashinfer_jit_cache") is None, + scope="module", +) def warmup_jit(): - # Skip warmup if flashinfer_jit_cache package is installed - if importlib.util.find_spec("flashinfer_jit_cache") is None: - jit_specs = [flashinfer.gemm.gen_gemm_module()] - if is_sm90a_supported(torch.device("cuda:0")): - jit_specs.append(flashinfer.gemm.gen_gemm_sm90_module()) - flashinfer.jit.build_jit_specs(jit_specs, verbose=False) + jit_specs = [flashinfer.gemm.gen_gemm_module()] + if is_sm90a_supported(torch.device("cuda:0")): + jit_specs.append(flashinfer.gemm.gen_gemm_sm90_module()) + flashinfer.jit.build_jit_specs(jit_specs, verbose=False) yield diff --git a/tests/utils/test_activation.py b/tests/utils/test_activation.py index d42f959d13..b46613c6cf 100644 --- a/tests/utils/test_activation.py +++ b/tests/utils/test_activation.py @@ -23,18 +23,19 @@ from flashinfer.utils import get_compute_capability -@pytest.fixture(autouse=True, scope="module") +@pytest.fixture( + autouse=importlib.util.find_spec("flashinfer_jit_cache") is None, + scope="module", +) def warmup_jit(): - # Skip warmup if flashinfer_jit_cache package is installed - if importlib.util.find_spec("flashinfer_jit_cache") is None: - flashinfer.jit.build_jit_specs( - [ - flashinfer.activation.gen_act_and_mul_module("silu"), - flashinfer.activation.gen_act_and_mul_module("gelu"), - flashinfer.activation.gen_act_and_mul_module("gelu_tanh"), - ], - verbose=False, - ) + flashinfer.jit.build_jit_specs( + [ + flashinfer.activation.gen_act_and_mul_module("silu"), + flashinfer.activation.gen_act_and_mul_module("gelu"), + flashinfer.activation.gen_act_and_mul_module("gelu_tanh"), + ], + verbose=False, + ) yield diff --git a/tests/utils/test_block_sparse.py b/tests/utils/test_block_sparse.py index 47e9204e63..46813ef497 100644 --- a/tests/utils/test_block_sparse.py +++ b/tests/utils/test_block_sparse.py @@ -28,30 +28,31 @@ import flashinfer -@pytest.fixture(autouse=True, scope="module") +@pytest.fixture( + autouse=importlib.util.find_spec("flashinfer_jit_cache") is None, + scope="module", +) def warmup_jit(): - # Skip warmup if flashinfer_jit_cache package is installed - if importlib.util.find_spec("flashinfer_jit_cache") is None: - flashinfer.jit.build_jit_specs( - gen_decode_attention_modules( - [torch.float16], # q_dtypes - [torch.float16], # kv_dtypes - [128, 256], # head_dims - [0], # pos_encoding_modes - [False], # use_sliding_windows - [False], # use_logits_soft_caps - ) - + gen_prefill_attention_modules( - [torch.float16], # q_dtypes - [torch.float16], # kv_dtypes - [128, 256], # head_dims - [0], # pos_encoding_modes - [False], # use_sliding_windows - [False], # use_logits_soft_caps - [False], # use_fp16_qk_reductions - ), - verbose=False, + flashinfer.jit.build_jit_specs( + gen_decode_attention_modules( + [torch.float16], # q_dtypes + [torch.float16], # kv_dtypes + [128, 256], # head_dims + [0], # pos_encoding_modes + [False], # use_sliding_windows + [False], # use_logits_soft_caps ) + + gen_prefill_attention_modules( + [torch.float16], # q_dtypes + [torch.float16], # kv_dtypes + [128, 256], # head_dims + [0], # pos_encoding_modes + [False], # use_sliding_windows + [False], # use_logits_soft_caps + [False], # use_fp16_qk_reductions + ), + verbose=False, + ) yield diff --git a/tests/utils/test_pod_kernels.py b/tests/utils/test_pod_kernels.py index 5cfbd674b2..221aeadf7c 100644 --- a/tests/utils/test_pod_kernels.py +++ b/tests/utils/test_pod_kernels.py @@ -27,48 +27,49 @@ from flashinfer.jit.attention import gen_pod_module -@pytest.fixture(autouse=True, scope="module") +@pytest.fixture( + autouse=importlib.util.find_spec("flashinfer_jit_cache") is None, + scope="module", +) def warmup_jit(): - # Skip warmup if flashinfer_jit_cache package is installed - if importlib.util.find_spec("flashinfer_jit_cache") is None: - flashinfer.jit.build_jit_specs( - gen_decode_attention_modules( - [torch.float16], # q_dtypes - [torch.float16], # kv_dtypes - [128], # head_dims - [0], # pos_encoding_modes - [False], # use_sliding_windows - [False], # use_fp16_qk_reductions - ) - + gen_prefill_attention_modules( - [torch.float16], # q_dtypes - [ - torch.float16, - ], # kv_dtypes - [128], # head_dims - [0], # pos_encoding_modes - [False], # use_sliding_windows - [False], # use_logits_soft_cap - [False], # use_fp16_qk_reductions - ) - + [ - gen_pod_module( - torch.float16, # dtype_q - torch.float16, # dtype_kv - torch.float16, # dtype_o - 128, # head_dim - 0, # pos_encoding_mode_p - False, # use_sliding_window_p - False, # use_logits_soft_cap_p - False, # use_fp16_qk_reduction - torch.int32, # dtype_idx - 0, # pos_encoding_mode_d - False, # use_sliding_window_d - False, # use_logits_soft_cap_d - ) - ], - verbose=False, + flashinfer.jit.build_jit_specs( + gen_decode_attention_modules( + [torch.float16], # q_dtypes + [torch.float16], # kv_dtypes + [128], # head_dims + [0], # pos_encoding_modes + [False], # use_sliding_windows + [False], # use_fp16_qk_reductions ) + + gen_prefill_attention_modules( + [torch.float16], # q_dtypes + [ + torch.float16, + ], # kv_dtypes + [128], # head_dims + [0], # pos_encoding_modes + [False], # use_sliding_windows + [False], # use_logits_soft_cap + [False], # use_fp16_qk_reductions + ) + + [ + gen_pod_module( + torch.float16, # dtype_q + torch.float16, # dtype_kv + torch.float16, # dtype_o + 128, # head_dim + 0, # pos_encoding_mode_p + False, # use_sliding_window_p + False, # use_logits_soft_cap_p + False, # use_fp16_qk_reduction + torch.int32, # dtype_idx + 0, # pos_encoding_mode_d + False, # use_sliding_window_d + False, # use_logits_soft_cap_d + ) + ], + verbose=False, + ) yield From dffa1be6f916d5b09a555b5f6cf067a13bc90bea Mon Sep 17 00:00:00 2001 From: Zihao Ye Date: Wed, 8 Oct 2025 00:20:53 -0400 Subject: [PATCH 37/41] upd --- README.md | 18 ++++++++++++++++-- docs/installation.rst | 20 ++++++++++++++++++-- flashinfer/__init__.py | 8 ++------ flashinfer/__main__.py | 4 +++- flashinfer/artifacts.py | 8 +++++++- flashinfer/jit/attention/modules.py | 5 ++++- flashinfer/jit/env.py | 2 +- flashinfer/version.py | 23 +++++++++++++++++++++++ pyproject.toml | 3 +++ version.txt | 2 +- 10 files changed, 78 insertions(+), 15 deletions(-) create mode 100644 flashinfer/version.py diff --git a/README.md b/README.md index bce1e9e78b..396b772ce9 100644 --- a/README.md +++ b/README.md @@ -103,12 +103,26 @@ Nightly builds are available for testing the latest features: ```bash # Core and cubin packages -pip install -U --pre flashinfer-python flashinfer-cubin --index-url https://flashinfer.ai/whl/nightly/ - +pip install -U --pre flashinfer-python --extra-index-url https://flashinfer.ai/whl/nightly/ +pip install -U --pre flashinfer-cubin --index-url https://flashinfer.ai/whl/nightly/ # JIT cache package (replace cu129 with your CUDA version: cu128, cu129, or cu130) pip install -U --pre flashinfer-jit-cache --index-url https://flashinfer.ai/whl/nightly/cu129 ``` +### Verify Installation + +After installation, verify that FlashInfer is correctly installed and configured: + +```bash +flashinfer show-config +``` + +This command displays: +- FlashInfer version and installed packages (flashinfer-python, flashinfer-cubin, flashinfer-jit-cache) +- PyTorch and CUDA version information +- Environment variables and artifact paths +- Downloaded cubin status and module compilation status + ### Trying it out Below is a minimal example of using FlashInfer's single-request decode/append/prefill attention kernels: diff --git a/docs/installation.rst b/docs/installation.rst index 35eee3ccd1..fe4d83571b 100644 --- a/docs/installation.rst +++ b/docs/installation.rst @@ -105,7 +105,23 @@ Nightly builds are available for testing the latest features: .. code-block:: bash # Core and cubin packages - pip install -U --pre flashinfer-python flashinfer-cubin --index-url https://flashinfer.ai/whl/nightly/ - + pip install -U --pre flashinfer-python --extra-index-url https://flashinfer.ai/whl/nightly/ + pip install -U --pre flashinfer-cubin --index-url https://flashinfer.ai/whl/nightly/ # JIT cache package (replace cu129 with your CUDA version: cu128, cu129, or cu130) pip install -U --pre flashinfer-jit-cache --index-url https://flashinfer.ai/whl/nightly/cu129 + +Verify Installation +^^^^^^^^^^^^^^^^^^^ + +After installation, verify that FlashInfer is correctly installed and configured: + +.. code-block:: bash + + flashinfer show-config + +This command displays: + +- FlashInfer version and installed packages (flashinfer-python, flashinfer-cubin, flashinfer-jit-cache) +- PyTorch and CUDA version information +- Environment variables and artifact paths +- Downloaded cubin status and module compilation status diff --git a/flashinfer/__init__.py b/flashinfer/__init__.py index ec08322649..866a91351c 100644 --- a/flashinfer/__init__.py +++ b/flashinfer/__init__.py @@ -16,12 +16,8 @@ import importlib.util -try: - from ._build_meta import __version__ as __version__ - from ._build_meta import __git_version__ as __git_version__ -except ModuleNotFoundError: - __version__ = "0.0.0+unknown" - __git_version__ = "unknown" +from .version import __version__ as __version__ +from .version import __git_version__ as __git_version__ from . import jit as jit diff --git a/flashinfer/__main__.py b/flashinfer/__main__.py index 701e1638ec..df3a3b5c4f 100644 --- a/flashinfer/__main__.py +++ b/flashinfer/__main__.py @@ -29,7 +29,9 @@ from .jit.env import FLASHINFER_CACHE_DIR, FLASHINFER_CUBIN_DIR from .jit.core import current_compilation_context from .jit.cpp_ext import get_cuda_path, get_cuda_version -from . import __version__ + +# Import __version__ from centralized version module +from .version import __version__ def _download_cubin(): diff --git a/flashinfer/artifacts.py b/flashinfer/artifacts.py index 7a446bc997..fe4aeab60b 100644 --- a/flashinfer/artifacts.py +++ b/flashinfer/artifacts.py @@ -15,6 +15,7 @@ """ from dataclasses import dataclass +import logging import os import re import time @@ -23,7 +24,12 @@ import requests # type: ignore[import-untyped] import shutil -from .jit.core import logger +# Create logger for artifacts module to avoid circular import with jit.core +logger = logging.getLogger("flashinfer.artifacts") +logger.setLevel(os.getenv("FLASHINFER_LOGGING_LEVEL", "INFO").upper()) +if not logger.handlers: + logger.addHandler(logging.StreamHandler()) + from .jit.cubin_loader import ( FLASHINFER_CUBINS_REPOSITORY, download_file, diff --git a/flashinfer/jit/attention/modules.py b/flashinfer/jit/attention/modules.py index c4a3ac9754..3fb4a289d3 100644 --- a/flashinfer/jit/attention/modules.py +++ b/flashinfer/jit/attention/modules.py @@ -20,7 +20,6 @@ import jinja2 import torch -from ...artifacts import ArtifactPath, MetaInfoHash from .. import env as jit_env from ..core import ( JitSpec, @@ -1569,6 +1568,8 @@ def gen_fmha_cutlass_sm100a_module( def gen_trtllm_gen_fmha_module(): + from ...artifacts import ArtifactPath, MetaInfoHash + include_path = f"{ArtifactPath.TRTLLM_GEN_FMHA}/include" header_name = "flashInferMetaInfo" @@ -1687,6 +1688,8 @@ def gen_customize_batch_attention_module( def gen_cudnn_fmha_module(): + from ...artifacts import ArtifactPath + return gen_jit_spec( "fmha_cudnn_gen", [jit_env.FLASHINFER_CSRC_DIR / "cudnn_sdpa_kernel_launcher.cu"], diff --git a/flashinfer/jit/env.py b/flashinfer/jit/env.py index 97330eba1c..ada807356e 100644 --- a/flashinfer/jit/env.py +++ b/flashinfer/jit/env.py @@ -22,7 +22,7 @@ import pathlib import importlib.util from ..compilation_context import CompilationContext -from .. import __version__ as flashinfer_version +from ..version import __version__ as flashinfer_version FLASHINFER_BASE_DIR: pathlib.Path = pathlib.Path( os.getenv("FLASHINFER_WORKSPACE_BASE", pathlib.Path.home().as_posix()) diff --git a/flashinfer/version.py b/flashinfer/version.py new file mode 100644 index 0000000000..95ad245497 --- /dev/null +++ b/flashinfer/version.py @@ -0,0 +1,23 @@ +""" +Copyright (c) 2023 by FlashInfer team. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +# Centralized version information to avoid circular imports +try: + from ._build_meta import __version__ as __version__ + from ._build_meta import __git_version__ as __git_version__ +except ModuleNotFoundError: + __version__ = "0.0.0+unknown" + __git_version__ = "unknown" diff --git a/pyproject.toml b/pyproject.toml index b3beb2be76..679a8b905d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -23,6 +23,9 @@ urls = { Homepage = "https://github.com/flashinfer-ai/flashinfer" } dynamic = ["dependencies", "version"] license-files = ["LICENSE", "licenses/*"] +[project.scripts] +flashinfer = "flashinfer.__main__:cli" + [build-system] requires = ["setuptools>=77", "packaging>=24", "apache-tvm-ffi==0.1.0b15"] build-backend = "build_backend" diff --git a/version.txt b/version.txt index 9e11b32fca..1d0ba9ea18 100644 --- a/version.txt +++ b/version.txt @@ -1 +1 @@ -0.3.1 +0.4.0 From 865f3ae58b16a6637365f4ff521dbca54c9e3e24 Mon Sep 17 00:00:00 2001 From: Zihao Ye Date: Wed, 8 Oct 2025 04:02:04 -0400 Subject: [PATCH 38/41] upd --- .github/workflows/nightly-release.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/nightly-release.yml b/.github/workflows/nightly-release.yml index dee60f1b10..dcdb1ee049 100644 --- a/.github/workflows/nightly-release.yml +++ b/.github/workflows/nightly-release.yml @@ -331,8 +331,8 @@ jobs: run: | DOCKER_IMAGE="flashinfer/flashinfer-ci-${{ steps.docker-tag.outputs.cuda_version }}:${{ steps.docker-tag.outputs.tag }}" bash ci/bash.sh ${DOCKER_IMAGE} \ - -e TEST_SHARD=${{ matrix.test-shard }} \ - -e FLASHINFER_JIT_CACHE_REPORT_FILE=/workspace/jit_cache_report_shard${{ matrix.test-shard }}_cuda${{ matrix.cuda }}.json \ + -e TEST_SHARD ${{ matrix.test-shard }} \ + -e FLASHINFER_JIT_CACHE_REPORT_FILE /workspace/jit_cache_report_shard${{ matrix.test-shard }}_cuda${{ matrix.cuda }}.json \ ./scripts/task_test_nightly_build.sh - name: Upload JIT cache report From e7d89b8a48a7c29ef9b65fa05fe4d1c6b599e95e Mon Sep 17 00:00:00 2001 From: Zihao Ye Date: Wed, 8 Oct 2025 12:45:44 -0400 Subject: [PATCH 39/41] upd --- flashinfer/jit/env.py | 6 ++--- flashinfer/utils.py | 24 +++++++++++++++++++ tests/attention/test_alibi.py | 5 ++-- tests/attention/test_attention_sink.py | 5 ++-- tests/attention/test_batch_attention.py | 6 ++--- tests/attention/test_batch_decode_kernels.py | 5 ++-- tests/attention/test_batch_invariant_fa2.py | 5 ++-- tests/attention/test_batch_prefill_kernels.py | 5 ++-- tests/attention/test_deepseek_mla.py | 4 ++-- tests/attention/test_logits_cap.py | 4 ++-- tests/attention/test_non_contiguous_decode.py | 5 ++-- .../attention/test_non_contiguous_prefill.py | 5 ++-- tests/attention/test_shared_prefix_kernels.py | 5 ++-- tests/attention/test_sliding_window.py | 5 ++-- tests/attention/test_tensor_cores_decode.py | 5 ++-- tests/gemm/test_group_gemm.py | 10 ++++---- tests/utils/test_activation.py | 6 ++--- tests/utils/test_block_sparse.py | 5 ++-- tests/utils/test_pod_kernels.py | 5 ++-- 19 files changed, 65 insertions(+), 55 deletions(-) diff --git a/flashinfer/jit/env.py b/flashinfer/jit/env.py index ada807356e..50d2973184 100644 --- a/flashinfer/jit/env.py +++ b/flashinfer/jit/env.py @@ -20,9 +20,9 @@ import os import pathlib -import importlib.util from ..compilation_context import CompilationContext from ..version import __version__ as flashinfer_version +from ..utils import has_flashinfer_jit_cache, has_flashinfer_cubin FLASHINFER_BASE_DIR: pathlib.Path = pathlib.Path( os.getenv("FLASHINFER_WORKSPACE_BASE", pathlib.Path.home().as_posix()) @@ -40,7 +40,7 @@ def _get_cubin_dir(): 3. Default cache directory """ # First check if flashinfer-cubin package is installed - if importlib.util.find_spec("flashinfer_cubin"): + if has_flashinfer_cubin(): import flashinfer_cubin flashinfer_cubin_version = flashinfer_cubin.__version__ @@ -77,7 +77,7 @@ def _get_aot_dir(): 2. Default fallback to _package_root / "data" / "aot" """ # First check if flashinfer-jit-cache package is installed - if importlib.util.find_spec("flashinfer_jit_cache"): + if has_flashinfer_jit_cache(): import flashinfer_jit_cache flashinfer_jit_cache_version = flashinfer_jit_cache.__version__ diff --git a/flashinfer/utils.py b/flashinfer/utils.py index 448f6d116a..9913fe750a 100644 --- a/flashinfer/utils.py +++ b/flashinfer/utils.py @@ -450,6 +450,30 @@ def has_cuda_cudart() -> bool: return importlib.util.find_spec("cuda.cudart") is not None +def has_flashinfer_jit_cache() -> bool: + """ + Check if flashinfer_jit_cache module is available. + + Returns: + True if flashinfer_jit_cache exists, False otherwise + """ + import importlib.util + + return importlib.util.find_spec("flashinfer_jit_cache") is not None + + +def has_flashinfer_cubin() -> bool: + """ + Check if flashinfer_cubin module is available. + + Returns: + True if flashinfer_cubin exists, False otherwise + """ + import importlib.util + + return importlib.util.find_spec("flashinfer_cubin") is not None + + def get_cuda_python_version() -> str: import cuda diff --git a/tests/attention/test_alibi.py b/tests/attention/test_alibi.py index 0971af2765..21114aea7c 100644 --- a/tests/attention/test_alibi.py +++ b/tests/attention/test_alibi.py @@ -14,8 +14,6 @@ limitations under the License. """ -import importlib.util - import pytest import torch from tests.test_helpers.alibi_reference import alibi_attention @@ -25,10 +23,11 @@ ) import flashinfer +from flashinfer.utils import has_flashinfer_jit_cache @pytest.fixture( - autouse=importlib.util.find_spec("flashinfer_jit_cache") is None, + autouse=not has_flashinfer_jit_cache(), scope="module", ) def warmup_jit(): diff --git a/tests/attention/test_attention_sink.py b/tests/attention/test_attention_sink.py index 3807a7d10f..aeacae1da1 100644 --- a/tests/attention/test_attention_sink.py +++ b/tests/attention/test_attention_sink.py @@ -14,7 +14,6 @@ limitations under the License. """ -import importlib.util import math import pytest @@ -25,11 +24,11 @@ from flashinfer.jit.utils import filename_safe_dtype_map from flashinfer.jit.attention import gen_batch_prefill_attention_sink_module from flashinfer.jit.attention.variants import attention_sink_decl -from flashinfer.utils import is_sm90a_supported +from flashinfer.utils import has_flashinfer_jit_cache, is_sm90a_supported @pytest.fixture( - autouse=importlib.util.find_spec("flashinfer_jit_cache") is None, + autouse=not has_flashinfer_jit_cache(), scope="module", ) def warmup_jit(): diff --git a/tests/attention/test_batch_attention.py b/tests/attention/test_batch_attention.py index 96150c488f..1a0532b479 100644 --- a/tests/attention/test_batch_attention.py +++ b/tests/attention/test_batch_attention.py @@ -14,8 +14,6 @@ limitations under the License. """ -import importlib.util - import numpy as np import pytest import torch @@ -25,11 +23,11 @@ gen_persistent_batch_attention_modules, gen_prefill_attention_modules, ) -from flashinfer.utils import get_compute_capability +from flashinfer.utils import get_compute_capability, has_flashinfer_jit_cache @pytest.fixture( - autouse=importlib.util.find_spec("flashinfer_jit_cache") is None, + autouse=not has_flashinfer_jit_cache(), scope="module", ) def warmup_jit(): diff --git a/tests/attention/test_batch_decode_kernels.py b/tests/attention/test_batch_decode_kernels.py index 20c578b102..39e736306a 100644 --- a/tests/attention/test_batch_decode_kernels.py +++ b/tests/attention/test_batch_decode_kernels.py @@ -14,8 +14,6 @@ limitations under the License. """ -import importlib.util - import pytest import torch from tests.test_helpers.jit_utils import ( @@ -24,10 +22,11 @@ ) from functools import partial import flashinfer +from flashinfer.utils import has_flashinfer_jit_cache @pytest.fixture( - autouse=importlib.util.find_spec("flashinfer_jit_cache") is None, + autouse=not has_flashinfer_jit_cache(), scope="module", ) def warmup_jit(): diff --git a/tests/attention/test_batch_invariant_fa2.py b/tests/attention/test_batch_invariant_fa2.py index e235fc54cb..ea7abeb2c7 100644 --- a/tests/attention/test_batch_invariant_fa2.py +++ b/tests/attention/test_batch_invariant_fa2.py @@ -14,8 +14,6 @@ limitations under the License. """ -import importlib.util - import pytest import torch from tests.test_helpers.jit_utils import ( @@ -24,10 +22,11 @@ ) import flashinfer +from flashinfer.utils import has_flashinfer_jit_cache @pytest.fixture( - autouse=importlib.util.find_spec("flashinfer_jit_cache") is None, + autouse=not has_flashinfer_jit_cache(), scope="module", ) def warmup_jit(): diff --git a/tests/attention/test_batch_prefill_kernels.py b/tests/attention/test_batch_prefill_kernels.py index 7bce4a3fd1..f067a70c62 100644 --- a/tests/attention/test_batch_prefill_kernels.py +++ b/tests/attention/test_batch_prefill_kernels.py @@ -14,18 +14,17 @@ limitations under the License. """ -import importlib.util - import numpy import pytest import torch from tests.test_helpers.jit_utils import gen_prefill_attention_modules import flashinfer +from flashinfer.utils import has_flashinfer_jit_cache @pytest.fixture( - autouse=importlib.util.find_spec("flashinfer_jit_cache") is None, + autouse=not has_flashinfer_jit_cache(), scope="module", ) def warmup_jit(): diff --git a/tests/attention/test_deepseek_mla.py b/tests/attention/test_deepseek_mla.py index 5712b8e804..0976c4ff39 100644 --- a/tests/attention/test_deepseek_mla.py +++ b/tests/attention/test_deepseek_mla.py @@ -14,7 +14,6 @@ limitations under the License. """ -import importlib.util import math import pytest @@ -29,6 +28,7 @@ gen_single_prefill_module, ) from flashinfer.utils import ( + has_flashinfer_jit_cache, is_sm90a_supported, is_sm100a_supported, is_sm110a_supported, @@ -36,7 +36,7 @@ @pytest.fixture( - autouse=importlib.util.find_spec("flashinfer_jit_cache") is None, + autouse=not has_flashinfer_jit_cache(), scope="module", ) def warmup_jit(): diff --git a/tests/attention/test_logits_cap.py b/tests/attention/test_logits_cap.py index b220200854..14791cac09 100644 --- a/tests/attention/test_logits_cap.py +++ b/tests/attention/test_logits_cap.py @@ -14,7 +14,6 @@ limitations under the License. """ -import importlib.util import math import pytest @@ -25,10 +24,11 @@ ) import flashinfer +from flashinfer.utils import has_flashinfer_jit_cache @pytest.fixture( - autouse=importlib.util.find_spec("flashinfer_jit_cache") is None, + autouse=not has_flashinfer_jit_cache(), scope="module", ) def warmup_jit(): diff --git a/tests/attention/test_non_contiguous_decode.py b/tests/attention/test_non_contiguous_decode.py index e9fa5c9bdc..c27ac11e5d 100644 --- a/tests/attention/test_non_contiguous_decode.py +++ b/tests/attention/test_non_contiguous_decode.py @@ -1,5 +1,3 @@ -import importlib.util - import pytest import torch from tests.test_helpers.jit_utils import ( @@ -8,10 +6,11 @@ ) import flashinfer +from flashinfer.utils import has_flashinfer_jit_cache @pytest.fixture( - autouse=importlib.util.find_spec("flashinfer_jit_cache") is None, + autouse=not has_flashinfer_jit_cache(), scope="module", ) def warmup_jit(): diff --git a/tests/attention/test_non_contiguous_prefill.py b/tests/attention/test_non_contiguous_prefill.py index 513488707e..96ad4aef05 100644 --- a/tests/attention/test_non_contiguous_prefill.py +++ b/tests/attention/test_non_contiguous_prefill.py @@ -14,17 +14,16 @@ limitations under the License. """ -import importlib.util - import pytest import torch from tests.test_helpers.jit_utils import gen_prefill_attention_modules import flashinfer +from flashinfer.utils import has_flashinfer_jit_cache @pytest.fixture( - autouse=importlib.util.find_spec("flashinfer_jit_cache") is None, + autouse=not has_flashinfer_jit_cache(), scope="module", ) def warmup_jit(): diff --git a/tests/attention/test_shared_prefix_kernels.py b/tests/attention/test_shared_prefix_kernels.py index 8cdf820b10..30aee0dc38 100644 --- a/tests/attention/test_shared_prefix_kernels.py +++ b/tests/attention/test_shared_prefix_kernels.py @@ -14,8 +14,6 @@ limitations under the License. """ -import importlib.util - import pytest import torch from tests.test_helpers.jit_utils import ( @@ -24,10 +22,11 @@ ) import flashinfer +from flashinfer.utils import has_flashinfer_jit_cache @pytest.fixture( - autouse=importlib.util.find_spec("flashinfer_jit_cache") is None, + autouse=not has_flashinfer_jit_cache(), scope="module", ) def warmup_jit(): diff --git a/tests/attention/test_sliding_window.py b/tests/attention/test_sliding_window.py index 93d21d00a0..e29c984d66 100644 --- a/tests/attention/test_sliding_window.py +++ b/tests/attention/test_sliding_window.py @@ -14,8 +14,6 @@ limitations under the License. """ -import importlib.util - import pytest import torch from tests.test_helpers.jit_utils import ( @@ -24,10 +22,11 @@ ) import flashinfer +from flashinfer.utils import has_flashinfer_jit_cache @pytest.fixture( - autouse=importlib.util.find_spec("flashinfer_jit_cache") is None, + autouse=not has_flashinfer_jit_cache(), scope="module", ) def warmup_jit(): diff --git a/tests/attention/test_tensor_cores_decode.py b/tests/attention/test_tensor_cores_decode.py index 7a00fd592c..19db15a640 100644 --- a/tests/attention/test_tensor_cores_decode.py +++ b/tests/attention/test_tensor_cores_decode.py @@ -14,8 +14,6 @@ limitations under the License. """ -import importlib.util - import pytest import torch from tests.test_helpers.jit_utils import ( @@ -24,10 +22,11 @@ ) from functools import partial import flashinfer +from flashinfer.utils import has_flashinfer_jit_cache @pytest.fixture( - autouse=importlib.util.find_spec("flashinfer_jit_cache") is None, + autouse=not has_flashinfer_jit_cache(), scope="module", ) def warmup_jit(): diff --git a/tests/gemm/test_group_gemm.py b/tests/gemm/test_group_gemm.py index ca65a33e84..fbdd9e26e4 100644 --- a/tests/gemm/test_group_gemm.py +++ b/tests/gemm/test_group_gemm.py @@ -14,20 +14,22 @@ limitations under the License. """ -import importlib.util - import pytest import torch import flashinfer -from flashinfer.utils import determine_gemm_backend, is_sm90a_supported +from flashinfer.utils import ( + determine_gemm_backend, + has_flashinfer_jit_cache, + is_sm90a_supported, +) DTYPES = [torch.float16] CUDA_DEVICES = ["cuda:0"] @pytest.fixture( - autouse=importlib.util.find_spec("flashinfer_jit_cache") is None, + autouse=not has_flashinfer_jit_cache(), scope="module", ) def warmup_jit(): diff --git a/tests/utils/test_activation.py b/tests/utils/test_activation.py index b46613c6cf..3a81681592 100644 --- a/tests/utils/test_activation.py +++ b/tests/utils/test_activation.py @@ -14,17 +14,15 @@ limitations under the License. """ -import importlib.util - import pytest import torch import flashinfer -from flashinfer.utils import get_compute_capability +from flashinfer.utils import get_compute_capability, has_flashinfer_jit_cache @pytest.fixture( - autouse=importlib.util.find_spec("flashinfer_jit_cache") is None, + autouse=not has_flashinfer_jit_cache(), scope="module", ) def warmup_jit(): diff --git a/tests/utils/test_block_sparse.py b/tests/utils/test_block_sparse.py index 46813ef497..46052d18a3 100644 --- a/tests/utils/test_block_sparse.py +++ b/tests/utils/test_block_sparse.py @@ -14,8 +14,6 @@ limitations under the License. """ -import importlib.util - import numpy as np import pytest import scipy as sp @@ -26,10 +24,11 @@ ) import flashinfer +from flashinfer.utils import has_flashinfer_jit_cache @pytest.fixture( - autouse=importlib.util.find_spec("flashinfer_jit_cache") is None, + autouse=not has_flashinfer_jit_cache(), scope="module", ) def warmup_jit(): diff --git a/tests/utils/test_pod_kernels.py b/tests/utils/test_pod_kernels.py index 221aeadf7c..8900cc1b6c 100644 --- a/tests/utils/test_pod_kernels.py +++ b/tests/utils/test_pod_kernels.py @@ -14,8 +14,6 @@ limitations under the License. """ -import importlib.util - import pytest import torch from tests.test_helpers.jit_utils import ( @@ -25,10 +23,11 @@ import flashinfer from flashinfer.jit.attention import gen_pod_module +from flashinfer.utils import has_flashinfer_jit_cache @pytest.fixture( - autouse=importlib.util.find_spec("flashinfer_jit_cache") is None, + autouse=not has_flashinfer_jit_cache(), scope="module", ) def warmup_jit(): From 5ee14dc801c064e6b41d32e0e1f435cf2d38f0c9 Mon Sep 17 00:00:00 2001 From: Zihao Ye Date: Wed, 8 Oct 2025 12:46:10 -0400 Subject: [PATCH 40/41] upd --- .github/workflows/nightly-release.yml | 2 -- 1 file changed, 2 deletions(-) diff --git a/.github/workflows/nightly-release.yml b/.github/workflows/nightly-release.yml index dcdb1ee049..bf5bfdc5be 100644 --- a/.github/workflows/nightly-release.yml +++ b/.github/workflows/nightly-release.yml @@ -10,8 +10,6 @@ on: description: 'Date suffix for dev version (YYYYMMDD, leave empty for today)' required: false type: string - pull_request: - # TODO: Remove this before merging - only for debugging this PR jobs: setup: From 26a8f1be2b470c7b57dd0696e0a1cecc3f6fc551 Mon Sep 17 00:00:00 2001 From: Zihao Ye Date: Wed, 8 Oct 2025 12:57:58 -0400 Subject: [PATCH 41/41] address circular dependency --- flashinfer/jit/env.py | 26 +++++++++++++++++++++++++- flashinfer/utils.py | 27 +++++---------------------- 2 files changed, 30 insertions(+), 23 deletions(-) diff --git a/flashinfer/jit/env.py b/flashinfer/jit/env.py index 50d2973184..4f50552d71 100644 --- a/flashinfer/jit/env.py +++ b/flashinfer/jit/env.py @@ -22,7 +22,31 @@ import pathlib from ..compilation_context import CompilationContext from ..version import __version__ as flashinfer_version -from ..utils import has_flashinfer_jit_cache, has_flashinfer_cubin + + +def has_flashinfer_jit_cache() -> bool: + """ + Check if flashinfer_jit_cache module is available. + + Returns: + True if flashinfer_jit_cache exists, False otherwise + """ + import importlib.util + + return importlib.util.find_spec("flashinfer_jit_cache") is not None + + +def has_flashinfer_cubin() -> bool: + """ + Check if flashinfer_cubin module is available. + + Returns: + True if flashinfer_cubin exists, False otherwise + """ + import importlib.util + + return importlib.util.find_spec("flashinfer_cubin") is not None + FLASHINFER_BASE_DIR: pathlib.Path = pathlib.Path( os.getenv("FLASHINFER_WORKSPACE_BASE", pathlib.Path.home().as_posix()) diff --git a/flashinfer/utils.py b/flashinfer/utils.py index 9913fe750a..d107c88298 100644 --- a/flashinfer/utils.py +++ b/flashinfer/utils.py @@ -450,28 +450,11 @@ def has_cuda_cudart() -> bool: return importlib.util.find_spec("cuda.cudart") is not None -def has_flashinfer_jit_cache() -> bool: - """ - Check if flashinfer_jit_cache module is available. - - Returns: - True if flashinfer_jit_cache exists, False otherwise - """ - import importlib.util - - return importlib.util.find_spec("flashinfer_jit_cache") is not None - - -def has_flashinfer_cubin() -> bool: - """ - Check if flashinfer_cubin module is available. - - Returns: - True if flashinfer_cubin exists, False otherwise - """ - import importlib.util - - return importlib.util.find_spec("flashinfer_cubin") is not None +# Re-export from jit.env to avoid circular dependency +from .jit.env import ( + has_flashinfer_jit_cache as has_flashinfer_jit_cache, + has_flashinfer_cubin as has_flashinfer_cubin, +) def get_cuda_python_version() -> str: