diff --git a/.github/workflows/nightly-release.yml b/.github/workflows/nightly-release.yml new file mode 100644 index 0000000000..bf5bfdc5be --- /dev/null +++ b/.github/workflows/nightly-release.yml @@ -0,0 +1,422 @@ +name: Nightly Release + +on: + schedule: + # Run at 00:00 UTC every day + - cron: '0 0 * * *' + workflow_dispatch: + inputs: + date_suffix: + description: 'Date suffix for dev version (YYYYMMDD, leave empty for today)' + required: false + type: string + +jobs: + setup: + runs-on: ubuntu-latest + outputs: + dev_suffix: ${{ steps.set-suffix.outputs.dev_suffix }} + release_tag: ${{ steps.set-suffix.outputs.release_tag }} + version: ${{ steps.set-suffix.outputs.version }} + steps: + - name: Checkout code + uses: actions/checkout@v4 + + - name: Set date suffix and release tag + id: set-suffix + run: | + # Read version from version.txt + VERSION=$(cat version.txt | tr -d '[:space:]') + + # Set date suffix + if [ -n "${{ inputs.date_suffix }}" ]; then + DEV_SUFFIX="${{ inputs.date_suffix }}" + else + DEV_SUFFIX=$(date -u +%Y%m%d) + fi + + # Create release tag with version + RELEASE_TAG="nightly-v${VERSION}-${DEV_SUFFIX}" + + echo "version=${VERSION}" >> $GITHUB_OUTPUT + echo "dev_suffix=${DEV_SUFFIX}" >> $GITHUB_OUTPUT + echo "release_tag=${RELEASE_TAG}" >> $GITHUB_OUTPUT + echo "Base version: ${VERSION}" + echo "Using dev suffix: ${DEV_SUFFIX}" + echo "Release tag: ${RELEASE_TAG}" + + build-flashinfer-python: + needs: setup + runs-on: ubuntu-latest + steps: + - name: Checkout code + uses: actions/checkout@v4 + with: + submodules: true + + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: '3.10' + + - name: Install build dependencies + run: | + python -m pip install --upgrade pip + pip install build wheel + + - name: Build flashinfer-python wheel and sdist + env: + FLASHINFER_DEV_RELEASE_SUFFIX: ${{ needs.setup.outputs.dev_suffix }} + run: | + echo "Building flashinfer-python with dev suffix: ${FLASHINFER_DEV_RELEASE_SUFFIX}" + echo "Git commit: $(git rev-parse HEAD)" + python -m build + ls -lh dist/ + + - name: Upload flashinfer-python artifact + uses: actions/upload-artifact@v4 + with: + name: flashinfer-python-dist + path: dist/* + retention-days: 7 + + build-flashinfer-cubin: + needs: setup + runs-on: ubuntu-latest + steps: + - name: Checkout code + uses: actions/checkout@v4 + with: + submodules: true + + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: '3.10' + + - name: Install build dependencies + run: | + python -m pip install --upgrade pip + pip install build twine wheel + pip install setuptools>=61.0 requests filelock torch tqdm numpy apache-tvm-ffi==0.1.0b15 + + - name: Build flashinfer-cubin wheel + env: + FLASHINFER_DEV_RELEASE_SUFFIX: ${{ needs.setup.outputs.dev_suffix }} + run: | + echo "Building flashinfer-cubin with dev suffix: ${FLASHINFER_DEV_RELEASE_SUFFIX}" + echo "Git commit: $(git rev-parse HEAD)" + cd flashinfer-cubin + rm -rf dist build *.egg-info + python -m build --wheel + ls -lh dist/ + mkdir -p ../dist + cp dist/*.whl ../dist/ + + - name: Upload flashinfer-cubin artifact + uses: actions/upload-artifact@v4 + with: + name: flashinfer-cubin-wheel + path: dist/*.whl + retention-days: 7 + + build-flashinfer-jit-cache: + needs: setup + strategy: + fail-fast: false + matrix: + cuda: ["12.8", "12.9", "13.0"] + arch: ['x86_64', 'aarch64'] + + runs-on: [self-hosted, "${{ matrix.arch == 'aarch64' && 'arm64' || matrix.arch }}"] + + steps: + - name: Display Machine Information + run: | + echo "CPU: $(nproc) cores, $(lscpu | grep 'Model name' | cut -d':' -f2 | xargs)" + echo "RAM: $(free -h | awk '/^Mem:/ {print $7 " available out of " $2}')" + echo "Disk: $(df -h / | awk 'NR==2 {print $4 " available out of " $2}')" + echo "Architecture: $(uname -m)" + + - name: Checkout code + uses: actions/checkout@v4 + with: + submodules: true + + - name: Build wheel in container + env: + DOCKER_IMAGE: ${{ matrix.arch == 'aarch64' && format('pytorch/manylinuxaarch64-builder:cuda{0}', matrix.cuda) || format('pytorch/manylinux2_28-builder:cuda{0}', matrix.cuda) }} + FLASHINFER_CUDA_ARCH_LIST: ${{ matrix.cuda == '12.8' && '7.5 8.0 8.9 9.0a 10.0a 12.0a' || '7.5 8.0 8.9 9.0a 10.0a 10.3a 12.0a' }} + FLASHINFER_DEV_RELEASE_SUFFIX: ${{ needs.setup.outputs.dev_suffix }} + run: | + # Extract CUDA major and minor versions + CUDA_MAJOR=$(echo "${{ matrix.cuda }}" | cut -d'.' -f1) + CUDA_MINOR=$(echo "${{ matrix.cuda }}" | cut -d'.' -f2) + export CUDA_MAJOR + export CUDA_MINOR + export CUDA_VERSION_SUFFIX="cu${CUDA_MAJOR}${CUDA_MINOR}" + + chown -R $(id -u):$(id -g) ${{ github.workspace }} + mkdir -p ${{ github.workspace }}/ci-cache + chown -R $(id -u):$(id -g) ${{ github.workspace }}/ci-cache + + # Run the build script inside the container with proper mounts + docker run --rm \ + -v ${{ github.workspace }}:/workspace \ + -v ${{ github.workspace }}/ci-cache:/ci-cache \ + -e FLASHINFER_CI_CACHE=/ci-cache \ + -e CUDA_VERSION="${{ matrix.cuda }}" \ + -e CUDA_MAJOR="$CUDA_MAJOR" \ + -e CUDA_MINOR="$CUDA_MINOR" \ + -e CUDA_VERSION_SUFFIX="$CUDA_VERSION_SUFFIX" \ + -e FLASHINFER_DEV_RELEASE_SUFFIX="${FLASHINFER_DEV_RELEASE_SUFFIX}" \ + -e ARCH="${{ matrix.arch }}" \ + -e FLASHINFER_CUDA_ARCH_LIST="${FLASHINFER_CUDA_ARCH_LIST}" \ + --user $(id -u):$(id -g) \ + -w /workspace \ + ${{ env.DOCKER_IMAGE }} \ + bash /workspace/scripts/build_flashinfer_jit_cache_whl.sh + timeout-minutes: 180 + + - name: Display wheel size + run: du -h flashinfer-jit-cache/dist/* + + - name: Create artifact name + id: artifact-name + run: | + CUDA_NO_DOT=$(echo "${{ matrix.cuda }}" | tr -d '.') + echo "name=jit-cache-cu${CUDA_NO_DOT}-${{ matrix.arch }}" >> $GITHUB_OUTPUT + + - name: Upload flashinfer-jit-cache artifact + uses: actions/upload-artifact@v4 + with: + name: ${{ steps.artifact-name.outputs.name }} + path: flashinfer-jit-cache/dist/*.whl + retention-days: 7 + + create-release: + needs: [setup, build-flashinfer-python, build-flashinfer-cubin, build-flashinfer-jit-cache] + runs-on: ubuntu-latest + permissions: + contents: write + steps: + - name: Checkout code + uses: actions/checkout@v4 + + - name: Create GitHub Release (empty first) + env: + GH_TOKEN: ${{ github.token }} + run: | + TAG="${{ needs.setup.outputs.release_tag }}" + + # Delete existing release and tag if they exist + if gh release view "$TAG" &>/dev/null; then + echo "Deleting existing release: $TAG" + gh release delete "$TAG" --yes --cleanup-tag + fi + + # Create new release without assets first + gh release create "$TAG" \ + --title "Nightly Release v${{ needs.setup.outputs.version }}-${{ needs.setup.outputs.dev_suffix }}" \ + --notes "Automated nightly build for version ${{ needs.setup.outputs.version }} (dev${{ needs.setup.outputs.dev_suffix }})" \ + --prerelease + + - name: Download flashinfer-python artifact + uses: actions/download-artifact@v4 + with: + name: flashinfer-python-dist + path: dist-python/ + + - name: Upload flashinfer-python to release + env: + GH_TOKEN: ${{ github.token }} + run: | + gh release upload "${{ needs.setup.outputs.release_tag }}" dist-python/* --clobber + + - name: Download flashinfer-cubin artifact + uses: actions/download-artifact@v4 + with: + name: flashinfer-cubin-wheel + path: dist-cubin/ + + - name: Upload flashinfer-cubin to release + env: + GH_TOKEN: ${{ github.token }} + run: | + gh release upload "${{ needs.setup.outputs.release_tag }}" dist-cubin/* --clobber + + - name: Upload flashinfer-jit-cache wheels to release (one at a time to avoid OOM) + env: + GH_TOKEN: ${{ github.token }} + run: | + # Upload jit-cache wheels one at a time to avoid OOM + # Each wheel can be several GB, so we download, upload, delete, repeat + mkdir -p dist-jit-cache + + for cuda in 128 129 130; do + for arch in x86_64 aarch64; do + ARTIFACT_NAME="jit-cache-cu${cuda}-${arch}" + echo "Processing ${ARTIFACT_NAME}..." + + # Download this specific artifact + gh run download ${{ github.run_id }} -n "${ARTIFACT_NAME}" -D dist-jit-cache/ || { + echo "Warning: Failed to download ${ARTIFACT_NAME}, skipping..." + continue + } + + # Upload to release + if [ -n "$(ls -A dist-jit-cache/)" ]; then + gh release upload "${{ needs.setup.outputs.release_tag }}" dist-jit-cache/* --clobber + echo "āœ… Uploaded ${ARTIFACT_NAME}" + fi + + # Clean up to save disk space before next iteration + rm -rf dist-jit-cache/* + done + done + + test-nightly-build: + needs: [setup, build-flashinfer-python, build-flashinfer-cubin, build-flashinfer-jit-cache] + strategy: + fail-fast: false + matrix: + cuda: ["12.9", "13.0"] + test-shard: [1, 2, 3, 4, 5] + runs-on: [self-hosted, G5, X64] + + steps: + - name: Display Machine Information + run: | + echo "CPU: $(nproc) cores, $(lscpu | grep 'Model name' | cut -d':' -f2 | xargs)" + echo "RAM: $(free -h | awk '/^Mem:/ {print $7 " available out of " $2}')" + echo "Disk: $(df -h / | awk 'NR==2 {print $4 " available out of " $2}')" + echo "Architecture: $(uname -m)" + + - name: Checkout code + uses: actions/checkout@v4 + with: + submodules: true + + - name: Download flashinfer-python artifact + uses: actions/download-artifact@v4 + with: + name: flashinfer-python-dist + path: dist-python/ + + - name: Download flashinfer-cubin artifact + uses: actions/download-artifact@v4 + with: + name: flashinfer-cubin-wheel + path: dist-cubin/ + + - name: Download flashinfer-jit-cache artifact + uses: actions/download-artifact@v4 + with: + name: jit-cache-cu${{ matrix.cuda == '12.9' && '129' || '130' }}-x86_64 + path: dist-jit-cache/ + + - name: Get Docker image tag + id: docker-tag + run: | + CUDA_VERSION="cu${{ matrix.cuda == '12.9' && '129' || '130' }}" + DOCKER_TAG=$(grep "flashinfer/flashinfer-ci-${CUDA_VERSION}" ci/docker-tags.yml | cut -d':' -f2 | tr -d ' ') + echo "cuda_version=${CUDA_VERSION}" >> $GITHUB_OUTPUT + echo "tag=${DOCKER_TAG}" >> $GITHUB_OUTPUT + + - name: Run nightly build tests in Docker (shard ${{ matrix.test-shard }}) + env: + CUDA_VISIBLE_DEVICES: 0 + run: | + DOCKER_IMAGE="flashinfer/flashinfer-ci-${{ steps.docker-tag.outputs.cuda_version }}:${{ steps.docker-tag.outputs.tag }}" + bash ci/bash.sh ${DOCKER_IMAGE} \ + -e TEST_SHARD ${{ matrix.test-shard }} \ + -e FLASHINFER_JIT_CACHE_REPORT_FILE /workspace/jit_cache_report_shard${{ matrix.test-shard }}_cuda${{ matrix.cuda }}.json \ + ./scripts/task_test_nightly_build.sh + + - name: Upload JIT cache report + if: always() + uses: actions/upload-artifact@v4 + with: + name: jit-cache-report-shard${{ matrix.test-shard }}-cuda${{ matrix.cuda }} + path: jit_cache_report_shard${{ matrix.test-shard }}_cuda${{ matrix.cuda }}.json + if-no-files-found: ignore + retention-days: 7 + + jit-cache-summary: + needs: test-nightly-build + if: always() + runs-on: ubuntu-latest + steps: + - name: Checkout code + uses: actions/checkout@v4 + + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: '3.10' + + - name: Download all JIT cache reports + uses: actions/download-artifact@v4 + with: + pattern: jit-cache-report-* + path: jit-reports/ + merge-multiple: true + + - name: Merge and print JIT cache summary + run: | + # Merge all report files into one + mkdir -p merged-reports + cat jit-reports/*.json > merged-reports/all_reports.json 2>/dev/null || echo "No JIT cache reports found" + + # Print summary + if [ -f merged-reports/all_reports.json ] && [ -s merged-reports/all_reports.json ]; then + python scripts/print_jit_cache_summary.py merged-reports/all_reports.json + else + echo "āœ… No missing JIT cache modules - all tests passed!" + fi + + update-wheel-index: + needs: [setup, create-release, test-nightly-build] + runs-on: ubuntu-latest + steps: + - name: Checkout flashinfer repo + uses: actions/checkout@v4 + + - name: Download all artifacts + uses: actions/download-artifact@v4 + with: + path: artifacts/ + + - name: Collect wheels and sdist + run: | + mkdir -p dist + find artifacts/ -name "*.whl" -exec cp {} dist/ \; + find artifacts/ -name "*.tar.gz" -exec cp {} dist/ \; + ls -lh dist/ + + - name: Clone wheel index + run: git clone https://oauth2:${WHL_TOKEN}@github.com/flashinfer-ai/whl.git flashinfer-whl + env: + WHL_TOKEN: ${{ secrets.WHL_TOKEN }} + + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: '3.10' + + - name: Update wheel index + run: | + python3 scripts/update_whl_index.py \ + --dist-dir dist \ + --output-dir flashinfer-whl \ + --release-tag "${{ needs.setup.outputs.release_tag }}" \ + --nightly + + - name: Push wheel index + run: | + cd flashinfer-whl + git config --local user.name "github-actions[bot]" + git config --local user.email "41898282+github-actions[bot]@users.noreply.github.com" + git add -A + git commit -m "update whl for nightly ${{ needs.setup.outputs.dev_suffix }}" + git push diff --git a/README.md b/README.md index 61963117d7..396b772ce9 100644 --- a/README.md +++ b/README.md @@ -42,50 +42,86 @@ FlashInfer supports PyTorch, TVM and C++ (header-only) APIs, and can be easily i Using our PyTorch API is the easiest way to get started: -### Install from PIP +### Install from PyPI -FlashInfer is available as a Python package for Linux on PyPI. You can install it with the following command: +FlashInfer is available as a Python package for Linux. Install the core package with: ```bash pip install flashinfer-python ``` +**Package Options:** +- **flashinfer-python**: Core package that compiles/downloads kernels on first use +- **flashinfer-cubin**: Pre-compiled kernel binaries for all supported GPU architectures +- **flashinfer-jit-cache**: Pre-built kernel cache for specific CUDA versions + +**For faster initialization and offline usage**, install the optional packages to have most kernels pre-compiled: +```bash +pip install flashinfer-python flashinfer-cubin +pip install flashinfer-jit-cache --index-url https://flashinfer.ai/whl/ +``` + +This eliminates compilation and downloading overhead at runtime. + ### Install from Source -Alternatively, build FlashInfer from source: +Build the core package from source: ```bash git clone https://github.com/flashinfer-ai/flashinfer.git --recursive cd flashinfer python -m pip install -v . +``` -# for development & contribution, install in editable mode +**For development**, install in editable mode: +```bash python -m pip install --no-build-isolation -e . -v ``` -`flashinfer-python` is a source-only package and by default it will JIT compile/download kernels on-the-fly. -For fully offline deployment, we also provide two additional packages `flashinfer-jit-cache` and `flashinfer-cubin`, to pre-compile and download cubins ahead-of-time. - -#### flashinfer-cubin +**Build optional packages:** -To build `flashinfer-cubin` package from source: +`flashinfer-cubin`: ```bash cd flashinfer-cubin python -m build --no-isolation --wheel python -m pip install dist/*.whl ``` -#### flashinfer-jit-cache - -To build `flashinfer-jit-cache` package from source: +`flashinfer-jit-cache` (customize `FLASHINFER_CUDA_ARCH_LIST` for your target GPUs): ```bash -export FLASHINFER_CUDA_ARCH_LIST="7.5 8.0 8.9 10.0a 10.3a 12.0a" # user can shrink the list to specific architectures +export FLASHINFER_CUDA_ARCH_LIST="7.5 8.0 8.9 10.0a 10.3a 12.0a" cd flashinfer-jit-cache python -m build --no-isolation --wheel python -m pip install dist/*.whl ``` -For more details, refer to the [Install from Source documentation](https://docs.flashinfer.ai/installation.html#install-from-source). +For more details, see the [Install from Source documentation](https://docs.flashinfer.ai/installation.html#install-from-source). + +### Install Nightly Build + +Nightly builds are available for testing the latest features: + +```bash +# Core and cubin packages +pip install -U --pre flashinfer-python --extra-index-url https://flashinfer.ai/whl/nightly/ +pip install -U --pre flashinfer-cubin --index-url https://flashinfer.ai/whl/nightly/ +# JIT cache package (replace cu129 with your CUDA version: cu128, cu129, or cu130) +pip install -U --pre flashinfer-jit-cache --index-url https://flashinfer.ai/whl/nightly/cu129 +``` + +### Verify Installation + +After installation, verify that FlashInfer is correctly installed and configured: + +```bash +flashinfer show-config +``` + +This command displays: +- FlashInfer version and installed packages (flashinfer-python, flashinfer-cubin, flashinfer-jit-cache) +- PyTorch and CUDA version information +- Environment variables and artifact paths +- Downloaded cubin status and module compilation status ### Trying it out diff --git a/build_backend.py b/build_backend.py new file mode 100644 index 0000000000..9672dfb2e4 --- /dev/null +++ b/build_backend.py @@ -0,0 +1,164 @@ +""" +Copyright (c) 2023 by FlashInfer team. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +import os +import shutil +from pathlib import Path + +from setuptools import build_meta as orig +from build_utils import get_git_version + +_root = Path(__file__).parent.resolve() +_data_dir = _root / "flashinfer" / "data" + + +def _create_build_metadata(): + """Create build metadata file with version information.""" + version_file = _root / "version.txt" + if version_file.exists(): + with open(version_file, "r") as f: + version = f.read().strip() + else: + version = "0.0.0+unknown" + + # Add dev suffix if specified + dev_suffix = os.environ.get("FLASHINFER_DEV_RELEASE_SUFFIX", "") + if dev_suffix: + version = f"{version}.dev{dev_suffix}" + + # Get git version + git_version = get_git_version(cwd=_root) + + # Create build metadata in the source tree + package_dir = Path(__file__).parent / "flashinfer" + build_meta_file = package_dir / "_build_meta.py" + + # Check if we're in a git repository + git_dir = Path(__file__).parent / ".git" + in_git_repo = git_dir.exists() + + # If file exists and not in git repo (installing from sdist), keep existing file + if build_meta_file.exists() and not in_git_repo: + print("Build metadata file already exists (not in git repo), keeping it") + return version + + # In git repo (editable) or file doesn't exist, create/update it + with open(build_meta_file, "w") as f: + f.write('"""Build metadata for flashinfer package."""\n') + f.write(f'__version__ = "{version}"\n') + f.write(f'__git_version__ = "{git_version}"\n') + + print(f"Created build metadata file with version {version}") + return version + + +# Create build metadata as soon as this module is imported +_create_build_metadata() + + +def write_if_different(path: Path, content: str) -> None: + if path.exists() and path.read_text() == content: + return + path.parent.mkdir(parents=True, exist_ok=True) + path.write_text(content) + + +def _create_data_dir(use_symlinks=True): + _data_dir.mkdir(parents=True, exist_ok=True) + + def ln(source: str, target: str) -> None: + src = _root / source + dst = _data_dir / target + if dst.exists(): + if dst.is_symlink(): + dst.unlink() + elif dst.is_dir(): + shutil.rmtree(dst) + else: + dst.unlink() + + if use_symlinks: + dst.symlink_to(src, target_is_directory=True) + else: + # For wheel/sdist, copy actual files instead of symlinks + if src.exists(): + shutil.copytree(src, dst, symlinks=False, dirs_exist_ok=True) + + ln("3rdparty/cutlass", "cutlass") + ln("3rdparty/spdlog", "spdlog") + ln("csrc", "csrc") + ln("include", "include") + + +def _prepare_for_wheel(): + # For wheel, copy actual files instead of symlinks so they are included in the wheel + if _data_dir.exists(): + shutil.rmtree(_data_dir) + _create_data_dir(use_symlinks=False) + + +def _prepare_for_editable(): + # For editable install, use symlinks so changes are reflected immediately + if _data_dir.exists(): + shutil.rmtree(_data_dir) + _create_data_dir(use_symlinks=True) + + +def _prepare_for_sdist(): + # For sdist, copy actual files instead of symlinks so they are included in the tarball + if _data_dir.exists(): + shutil.rmtree(_data_dir) + _create_data_dir(use_symlinks=False) + + +def get_requires_for_build_wheel(config_settings=None): + _prepare_for_wheel() + return [] + + +def get_requires_for_build_sdist(config_settings=None): + _prepare_for_sdist() + return [] + + +def get_requires_for_build_editable(config_settings=None): + _prepare_for_editable() + return [] + + +def prepare_metadata_for_build_wheel(metadata_directory, config_settings=None): + _prepare_for_wheel() + return orig.prepare_metadata_for_build_wheel(metadata_directory, config_settings) + + +def prepare_metadata_for_build_editable(metadata_directory, config_settings=None): + _prepare_for_editable() + return orig.prepare_metadata_for_build_editable(metadata_directory, config_settings) + + +def build_editable(wheel_directory, config_settings=None, metadata_directory=None): + _prepare_for_editable() + return orig.build_editable(wheel_directory, config_settings, metadata_directory) + + +def build_sdist(sdist_directory, config_settings=None): + _prepare_for_sdist() + return orig.build_sdist(sdist_directory, config_settings) + + +def build_wheel(wheel_directory, config_settings=None, metadata_directory=None): + _prepare_for_wheel() + return orig.build_wheel(wheel_directory, config_settings, metadata_directory) diff --git a/build_utils.py b/build_utils.py new file mode 100644 index 0000000000..726a628204 --- /dev/null +++ b/build_utils.py @@ -0,0 +1,46 @@ +""" +Copyright (c) 2025 by FlashInfer team. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +"""Shared build utilities for flashinfer packages.""" + +import subprocess +from pathlib import Path +from typing import Optional + + +def get_git_version(cwd: Optional[Path] = None) -> str: + """ + Get git commit hash. + + Args: + cwd: Working directory for git command. If None, uses current directory. + + Returns: + Git commit hash or "unknown" if git is not available. + """ + try: + git_version = ( + subprocess.check_output( + ["git", "rev-parse", "HEAD"], + cwd=cwd, + stderr=subprocess.DEVNULL, + ) + .decode("ascii") + .strip() + ) + return git_version + except Exception: + return "unknown" diff --git a/custom_backend.py b/custom_backend.py deleted file mode 100644 index 0484d714fd..0000000000 --- a/custom_backend.py +++ /dev/null @@ -1,80 +0,0 @@ -import shutil -from pathlib import Path - -from setuptools import build_meta as orig - -_root = Path(__file__).parent.resolve() -_data_dir = _root / "flashinfer" / "data" - - -def _create_data_dir(): - _data_dir.mkdir(parents=True, exist_ok=True) - - def ln(source: str, target: str) -> None: - src = _root / source - dst = _data_dir / target - if dst.exists(): - if dst.is_symlink(): - dst.unlink() - elif dst.is_dir(): - dst.rmdir() - dst.symlink_to(src, target_is_directory=True) - - ln("3rdparty/cutlass", "cutlass") - ln("3rdparty/spdlog", "spdlog") - ln("csrc", "csrc") - ln("include", "include") - - -def _prepare_for_wheel(): - # Remove data directory - if _data_dir.exists(): - shutil.rmtree(_data_dir) - - -def _prepare_for_editable(): - _create_data_dir() - - -def _prepare_for_sdist(): - # Remove data directory - if _data_dir.exists(): - shutil.rmtree(_data_dir) - - -def get_requires_for_build_wheel(config_settings=None): - _prepare_for_wheel() - - -def get_requires_for_build_sdist(config_settings=None): - _prepare_for_sdist() - return [] - - -def get_requires_for_build_editable(config_settings=None): - _prepare_for_editable() - - -def prepare_metadata_for_build_wheel(metadata_directory, config_settings=None): - _prepare_for_wheel() - return orig.prepare_metadata_for_build_wheel(metadata_directory, config_settings) - - -def prepare_metadata_for_build_editable(metadata_directory, config_settings=None): - _prepare_for_editable() - return orig.prepare_metadata_for_build_editable(metadata_directory, config_settings) - - -def build_editable(wheel_directory, config_settings=None, metadata_directory=None): - _prepare_for_editable() - return orig.build_editable(wheel_directory, config_settings, metadata_directory) - - -def build_sdist(sdist_directory, config_settings=None): - _prepare_for_sdist() - return orig.build_sdist(sdist_directory, config_settings) - - -def build_wheel(wheel_directory, config_settings=None, metadata_directory=None): - _prepare_for_wheel() - return orig.build_wheel(wheel_directory, config_settings, metadata_directory) diff --git a/docs/conf.py b/docs/conf.py index 2d9ccf5366..8b61d26c8c 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -1,4 +1,3 @@ -from pathlib import Path from typing import Any, List import flashinfer # noqa: F401 @@ -12,7 +11,6 @@ # -- Project information ----------------------------------------------------- # https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information -root = Path(__file__).parents[1].resolve() # FlashInfer is installed via pip before building docs autodoc_mock_imports = [ "torch", @@ -28,9 +26,8 @@ author = "FlashInfer Contributors" copyright = f"2023-2025, {author}" -package_version = (root / "version.txt").read_text().strip() -version = package_version -release = package_version +version = flashinfer.__version__ +release = flashinfer.__version__ # -- General configuration --------------------------------------------------- # https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration diff --git a/docs/installation.rst b/docs/installation.rst index 3700dc0f3f..fe4d83571b 100644 --- a/docs/installation.rst +++ b/docs/installation.rst @@ -24,23 +24,31 @@ The easiest way to install FlashInfer is via pip. Please note that the package c pip install flashinfer-python +Package Options +""""""""""""""" -.. _install-from-source: +FlashInfer provides three packages: -Install from Source -^^^^^^^^^^^^^^^^^^^ +- **flashinfer-python**: Core package that compiles/downloads kernels on first use +- **flashinfer-cubin**: Pre-compiled kernel binaries for all supported GPU architectures +- **flashinfer-jit-cache**: Pre-built kernel cache for specific CUDA versions -In certain cases, you may want to install FlashInfer from source code to try out the latest features in the main branch, or to customize the library for your specific needs. +**For faster initialization and offline usage**, install the optional packages to have most kernels pre-compiled: -``flashinfer-python`` is a source-only package and by default it will JIT compile/download kernels on-the-fly. +.. code-block:: bash + + pip install flashinfer-python flashinfer-cubin + pip install flashinfer-jit-cache --index-url https://flashinfer.ai/whl/ -For fully offline deployment, we also provide two additional packages to pre-compile and download cubins ahead-of-time: +This eliminates compilation and downloading overhead at runtime. -flashinfer-cubin - - Provides pre-compiled CUDA binaries for immediate use without runtime compilation. -flashinfer-jit-cache - - Pre-compiles kernels for specific CUDA architectures to enable fully offline deployment. +.. _install-from-source: + +Install from Source +^^^^^^^^^^^^^^^^^^^ + +In certain cases, you may want to install FlashInfer from source code to try out the latest features in the main branch, or to customize the library for your specific needs. You can follow the steps below to install FlashInfer from source code: @@ -63,15 +71,15 @@ You can follow the steps below to install FlashInfer from source code: cd flashinfer python -m pip install -v . - For development & contribution, install in editable mode: + **For development**, install in editable mode: .. code-block:: bash python -m pip install --no-build-isolation -e . -v -4. (Optional) Build additional packages for offline deployment: +4. (Optional) Build optional packages: - To build ``flashinfer-cubin`` package from source: + Build ``flashinfer-cubin``: .. code-block:: bash @@ -79,11 +87,41 @@ You can follow the steps below to install FlashInfer from source code: python -m build --no-isolation --wheel python -m pip install dist/*.whl - To build ``flashinfer-jit-cache`` package from source: + Build ``flashinfer-jit-cache`` (customize ``FLASHINFER_CUDA_ARCH_LIST`` for your target GPUs): .. code-block:: bash - export FLASHINFER_CUDA_ARCH_LIST="7.5 8.0 8.9 10.0a 10.3a 12.0a" # user can shrink the list to specific architectures + export FLASHINFER_CUDA_ARCH_LIST="7.5 8.0 8.9 10.0a 10.3a 12.0a" cd flashinfer-jit-cache python -m build --no-isolation --wheel python -m pip install dist/*.whl + + +Install Nightly Build +^^^^^^^^^^^^^^^^^^^^^^ + +Nightly builds are available for testing the latest features: + +.. code-block:: bash + + # Core and cubin packages + pip install -U --pre flashinfer-python --extra-index-url https://flashinfer.ai/whl/nightly/ + pip install -U --pre flashinfer-cubin --index-url https://flashinfer.ai/whl/nightly/ + # JIT cache package (replace cu129 with your CUDA version: cu128, cu129, or cu130) + pip install -U --pre flashinfer-jit-cache --index-url https://flashinfer.ai/whl/nightly/cu129 + +Verify Installation +^^^^^^^^^^^^^^^^^^^ + +After installation, verify that FlashInfer is correctly installed and configured: + +.. code-block:: bash + + flashinfer show-config + +This command displays: + +- FlashInfer version and installed packages (flashinfer-python, flashinfer-cubin, flashinfer-jit-cache) +- PyTorch and CUDA version information +- Environment variables and artifact paths +- Downloaded cubin status and module compilation status diff --git a/flashinfer-cubin/build_backend.py b/flashinfer-cubin/build_backend.py index 8815390ad5..0021c1da67 100644 --- a/flashinfer-cubin/build_backend.py +++ b/flashinfer-cubin/build_backend.py @@ -6,20 +6,14 @@ import sys from pathlib import Path from setuptools import build_meta as _orig -from setuptools.build_meta import * # Add parent directory to path to import artifacts module sys.path.insert(0, str(Path(__file__).parent.parent)) -# add flashinfer._build_meta, always override to ensure version is up-to-date -build_meta_file = Path(__file__).parent.parent / "flashinfer" / "_build_meta.py" -version_file = Path(__file__).parent.parent / "version.txt" -if version_file.exists(): - with open(version_file, "r") as f: - version = f.read().strip() -with open(build_meta_file, "w") as f: - f.write('"""Build metadata for flashinfer package."""\n') - f.write(f'__version__ = "{version}"\n') +from build_utils import get_git_version + +# Skip version check when building flashinfer-cubin package +os.environ["FLASHINFER_DISABLE_VERSION_CHECK"] = "1" def _download_cubins(): @@ -60,29 +54,50 @@ def _create_build_metadata(): else: version = "0.0.0+unknown" + # Add dev suffix if specified + dev_suffix = os.environ.get("FLASHINFER_DEV_RELEASE_SUFFIX", "") + if dev_suffix: + version = f"{version}.dev{dev_suffix}" + + # Get git version + git_version = get_git_version(cwd=Path(__file__).parent.parent) + # Create build metadata in the source tree package_dir = Path(__file__).parent / "flashinfer_cubin" build_meta_file = package_dir / "_build_meta.py" + # Check if we're in a git repository + git_dir = Path(__file__).parent.parent / ".git" + in_git_repo = git_dir.exists() + + # If file exists and not in git repo (installing from sdist), keep existing file + if build_meta_file.exists() and not in_git_repo: + print("Build metadata file already exists (not in git repo), keeping it") + return version + + # In git repo (editable) or file doesn't exist, create/update it with open(build_meta_file, "w") as f: f.write('"""Build metadata for flashinfer-cubin package."""\n') f.write(f'__version__ = "{version}"\n') + f.write(f'__git_version__ = "{git_version}"\n') print(f"Created build metadata file with version {version}") return version +# Create build metadata as soon as this module is imported +_create_build_metadata() + + def build_wheel(wheel_directory, config_settings=None, metadata_directory=None): """Build a wheel, downloading cubins first.""" _download_cubins() - _create_build_metadata() return _orig.build_wheel(wheel_directory, config_settings, metadata_directory) def build_editable(wheel_directory, config_settings=None, metadata_directory=None): """Build an editable install, downloading cubins first.""" _download_cubins() - _create_build_metadata() return _orig.build_editable(wheel_directory, config_settings, metadata_directory) diff --git a/flashinfer-cubin/download_cubins.py b/flashinfer-cubin/download_cubins.py deleted file mode 100644 index 2f2c847bbe..0000000000 --- a/flashinfer-cubin/download_cubins.py +++ /dev/null @@ -1,80 +0,0 @@ -#!/usr/bin/env python3 -""" -Copyright (c) 2025 by FlashInfer team. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -""" - -import os -import sys -import argparse -from pathlib import Path - -# Add parent directory to path to import flashinfer modules -sys.path.insert(0, str(Path(__file__).parent.parent)) - -from flashinfer.artifacts import download_artifacts -from flashinfer.jit.cubin_loader import FLASHINFER_CUBINS_REPOSITORY - - -def main(): - parser = argparse.ArgumentParser( - description="Download FlashInfer cubins from artifactory" - ) - parser.add_argument( - "--output-dir", - "-o", - type=str, - default="flashinfer_cubin/cubins", - help="Output directory for cubins (default: flashinfer_cubin/cubins)", - ) - parser.add_argument( - "--threads", - "-t", - type=int, - default=4, - help="Number of download threads (default: 4)", - ) - parser.add_argument( - "--repository", - "-r", - type=str, - default=None, - help="Override the cubins repository URL", - ) - - args = parser.parse_args() - - # Set environment variables to control download behavior - if args.repository: - os.environ["FLASHINFER_CUBINS_REPOSITORY"] = args.repository - - os.environ["FLASHINFER_CUBIN_DIR"] = str(Path(args.output_dir).absolute()) - os.environ["FLASHINFER_CUBIN_DOWNLOAD_THREADS"] = str(args.threads) - - print(f"Downloading cubins to {args.output_dir}") - print( - f"Repository: {os.environ.get('FLASHINFER_CUBINS_REPOSITORY', FLASHINFER_CUBINS_REPOSITORY)}" - ) - - # Use the existing download_artifacts function - try: - download_artifacts() - print("Download complete!") - except Exception as e: - print(f"Download failed: {e}") - sys.exit(1) - - -if __name__ == "__main__": - main() diff --git a/flashinfer-cubin/flashinfer_cubin/__init__.py b/flashinfer-cubin/flashinfer_cubin/__init__.py index b28fb09da5..abfe816282 100644 --- a/flashinfer-cubin/flashinfer_cubin/__init__.py +++ b/flashinfer-cubin/flashinfer_cubin/__init__.py @@ -63,5 +63,18 @@ def _get_version(): return "0.0.0" +def _get_git_version(): + # First try to read from build metadata (for wheel distributions) + try: + from . import _build_meta + + return _build_meta.__git_version__ + except (ImportError, AttributeError): + pass + + return "unknown" + + __version__ = _get_version() +__git_version__ = _get_git_version() __all__ = ["get_cubin_dir", "list_cubins", "get_cubin_path", "CUBIN_DIR"] diff --git a/flashinfer-jit-cache/build_backend.py b/flashinfer-jit-cache/build_backend.py index 7e0d9f4aaa..b9a1739070 100644 --- a/flashinfer-jit-cache/build_backend.py +++ b/flashinfer-jit-cache/build_backend.py @@ -16,24 +16,22 @@ import sys import os +import platform from pathlib import Path from setuptools import build_meta as _orig +from wheel.bdist_wheel import bdist_wheel # Add parent directory to path to import flashinfer modules sys.path.insert(0, str(Path(__file__).parent.parent)) -# add flashinfer._build_meta, always override to ensure version is up-to-date -build_meta_file = Path(__file__).parent.parent / "flashinfer" / "_build_meta.py" -version_file = Path(__file__).parent.parent / "version.txt" -if version_file.exists(): - with open(version_file, "r") as f: - version = f.read().strip() -with open(build_meta_file, "w") as f: - f.write('"""Build metadata for flashinfer package."""\n') - f.write(f'__version__ = "{version}"\n') +from build_utils import get_git_version +# Skip version check when building flashinfer-jit-cache package +os.environ["FLASHINFER_DISABLE_VERSION_CHECK"] = "1" -def get_version(): + +def _create_build_metadata(): + """Create build metadata file with version information.""" version_file = Path(__file__).parent.parent / "version.txt" if version_file.exists(): with open(version_file, "r") as f: @@ -41,20 +39,45 @@ def get_version(): else: version = "0.0.0+unknown" + # Add dev suffix if specified + dev_suffix = os.environ.get("FLASHINFER_DEV_RELEASE_SUFFIX", "") + if dev_suffix: + version = f"{version}.dev{dev_suffix}" + + # Get git version + git_version = get_git_version(cwd=Path(__file__).parent.parent) + # Append CUDA version suffix if available cuda_suffix = os.environ.get("CUDA_VERSION_SUFFIX", "") if cuda_suffix: - # Replace + with . for proper version formatting - if "+" in version: - base_version, local = version.split("+", 1) - version = f"{base_version}+{cuda_suffix}.{local}" - else: - version = f"{version}+{cuda_suffix}" + # Use + to create a local version identifier that will appear in wheel name + version = f"{version}+{cuda_suffix}" + build_meta_file = Path(__file__).parent / "flashinfer_jit_cache" / "_build_meta.py" + + # Check if we're in a git repository + git_dir = Path(__file__).parent.parent / ".git" + in_git_repo = git_dir.exists() + + # If file exists and not in git repo (installing from sdist), keep existing file + if build_meta_file.exists() and not in_git_repo: + print("Build metadata file already exists (not in git repo), keeping it") + return version + + # In git repo (editable) or file doesn't exist, create/update it + with open(build_meta_file, "w") as f: + f.write('"""Build metadata for flashinfer-jit-cache package."""\n') + f.write(f'__version__ = "{version}"\n') + f.write(f'__git_version__ = "{git_version}"\n') + print(f"Created build metadata file with version {version}") return version -def compile_jit_cache(output_dir: Path, verbose: bool = True): +# Create build metadata as soon as this module is imported +_create_build_metadata() + + +def _compile_jit_cache(output_dir: Path, verbose: bool = True): """Compile AOT modules using flashinfer.aot functions directly.""" from flashinfer import aot @@ -75,15 +98,14 @@ def compile_jit_cache(output_dir: Path, verbose: bool = True): ) -def _prepare_build(): - """Shared preparation logic for both wheel and editable builds.""" +def _build_aot_modules(): # First, ensure AOT modules are compiled aot_package_dir = Path(__file__).parent / "flashinfer_jit_cache" / "jit_cache" aot_package_dir.mkdir(parents=True, exist_ok=True) try: # Compile AOT modules - compile_jit_cache(aot_package_dir) + _compile_jit_cache(aot_package_dir) # Verify that some modules were actually compiled so_files = list(aot_package_dir.rglob("*.so")) @@ -96,16 +118,61 @@ def _prepare_build(): print(f"Failed to compile AOT modules: {e}") raise - # Create build metadata file with version information - package_dir = Path(__file__).parent / "flashinfer_jit_cache" - build_meta_file = package_dir / "_build_meta.py" - version = get_version() - with open(build_meta_file, "w") as f: - f.write('"""Build metadata for flashinfer-jit-cache package."""\n') - f.write(f'__version__ = "{version}"\n') +def _prepare_build(): + """Shared preparation logic for both wheel and editable builds.""" + _build_aot_modules() + + +class PlatformSpecificBdistWheel(bdist_wheel): + """Custom wheel builder that uses py_limited_api for cp39+.""" + + def finalize_options(self): + super().finalize_options() + # Force platform-specific wheel (not pure Python) + self.root_is_pure = False + # Use py_limited_api for cp39 (Python 3.9+) + self.py_limited_api = "cp39" + + def get_tag(self): + # Use py_limited_api tags + python_tag = "cp39" + abi_tag = "abi3" # Stable ABI tag + + # Get platform tag + machine = platform.machine() + if platform.system() == "Linux": + # Use manylinux_2_28 as specified + if machine == "x86_64": + plat_tag = "manylinux_2_28_x86_64" + elif machine == "aarch64": + plat_tag = "manylinux_2_28_aarch64" + else: + plat_tag = f"linux_{machine}" + else: + # For non-Linux platforms, use the default + import distutils.util - print(f"Created build metadata file with version {version}") + plat_tag = distutils.util.get_platform().replace("-", "_").replace(".", "_") + + return python_tag, abi_tag, plat_tag + + +class _MonkeyPatchBdistWheel: + """Context manager to temporarily replace bdist_wheel with our custom class.""" + + def __enter__(self): + from setuptools.command import bdist_wheel as setuptools_bdist_wheel + + self.original_bdist_wheel = setuptools_bdist_wheel.bdist_wheel + setuptools_bdist_wheel.bdist_wheel = PlatformSpecificBdistWheel + + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + from setuptools.command import bdist_wheel as setuptools_bdist_wheel + + setuptools_bdist_wheel.bdist_wheel = self.original_bdist_wheel def build_wheel(wheel_directory, config_settings=None, metadata_directory=None): @@ -114,11 +181,8 @@ def build_wheel(wheel_directory, config_settings=None, metadata_directory=None): _prepare_build() - # Now build the wheel using setuptools - # The setup.py file will handle the platform-specific wheel naming - result = _orig.build_wheel(wheel_directory, config_settings, metadata_directory) - - return result + with _MonkeyPatchBdistWheel(): + return _orig.build_wheel(wheel_directory, config_settings, metadata_directory) def build_editable(wheel_directory, config_settings=None, metadata_directory=None): @@ -137,9 +201,24 @@ def build_editable(wheel_directory, config_settings=None, metadata_directory=Non return result +def prepare_metadata_for_build_wheel(metadata_directory, config_settings=None): + """Prepare metadata with platform-specific wheel tags.""" + with _MonkeyPatchBdistWheel(): + return _orig.prepare_metadata_for_build_wheel( + metadata_directory, config_settings + ) + + +def prepare_metadata_for_build_editable(metadata_directory, config_settings=None): + """Prepare metadata for editable install.""" + with _MonkeyPatchBdistWheel(): + return _orig.prepare_metadata_for_build_editable( + metadata_directory, config_settings + ) + + # Export the required interface get_requires_for_build_wheel = _orig.get_requires_for_build_wheel -prepare_metadata_for_build_wheel = _orig.prepare_metadata_for_build_wheel get_requires_for_build_editable = getattr( _orig, "get_requires_for_build_editable", None ) diff --git a/flashinfer-jit-cache/flashinfer_jit_cache/__init__.py b/flashinfer-jit-cache/flashinfer_jit_cache/__init__.py index 986232c63a..3be467e99a 100644 --- a/flashinfer-jit-cache/flashinfer_jit_cache/__init__.py +++ b/flashinfer-jit-cache/flashinfer_jit_cache/__init__.py @@ -29,8 +29,10 @@ def get_jit_cache_dir() -> str: try: from ._build_meta import __version__ as __version__ + from ._build_meta import __git_version__ as __git_version__ except ModuleNotFoundError: __version__ = "0.0.0+unknown" + __git_version__ = "unknown" __all__ = [ diff --git a/flashinfer-jit-cache/setup.py b/flashinfer-jit-cache/setup.py deleted file mode 100644 index e65b216cce..0000000000 --- a/flashinfer-jit-cache/setup.py +++ /dev/null @@ -1,109 +0,0 @@ -"""Setup script for flashinfer-jit-cache package.""" - -import os -import platform -from pathlib import Path -from setuptools import setup, find_packages -from wheel.bdist_wheel import bdist_wheel - - -def get_version(): - """Get version from version.txt file.""" - version_file = Path(__file__).parent.parent / "version.txt" - if version_file.exists(): - with open(version_file, "r") as f: - version = f.read().strip() - else: - version = "0.0.0" - - # Append CUDA version suffix if available - cuda_suffix = os.environ.get("CUDA_VERSION_SUFFIX", "") - if cuda_suffix: - # Use + to create a local version identifier that will appear in wheel name - version = f"{version}+{cuda_suffix}" - - return version - - -def generate_build_meta(): - """Generate build metadata file.""" - build_meta_file = Path(__file__).parent / "flashinfer_jit_cache" / "_build_meta.py" - version = get_version() - with open(build_meta_file, "w") as f: - f.write('"""Build metadata for flashinfer-jit-cache package."""\n') - f.write(f'__version__ = "{version}"\n') - - -class PlatformSpecificBdistWheel(bdist_wheel): - """Custom wheel builder that uses py_limited_api for cp39+.""" - - def finalize_options(self): - super().finalize_options() - # Force platform-specific wheel (not pure Python) - self.root_is_pure = False - # Use py_limited_api for cp39 (Python 3.9+) - self.py_limited_api = "cp39" - - def get_tag(self): - # Use py_limited_api tags - python_tag = "cp39" - abi_tag = "abi3" # Stable ABI tag - - # Get platform tag - machine = platform.machine() - if platform.system() == "Linux": - # Use manylinux_2_28 as specified - if machine == "x86_64": - plat_tag = "manylinux_2_28_x86_64" - elif machine == "aarch64": - plat_tag = "manylinux_2_28_aarch64" - else: - plat_tag = f"linux_{machine}" - else: - # For non-Linux platforms, use the default - import distutils.util - - plat_tag = distutils.util.get_platform().replace("-", "_").replace(".", "_") - - return python_tag, abi_tag, plat_tag - - -if __name__ == "__main__": - generate_build_meta() - setup( - name="flashinfer-jit-cache", - version=get_version(), - description="Pre-compiled AOT modules for FlashInfer", - long_description="This package contains pre-compiled AOT modules for FlashInfer. It provides all necessary compiled shared libraries (.so files) for optimized inference operations.", - long_description_content_type="text/plain", - author="FlashInfer team", - maintainer="FlashInfer team", - url="https://github.com/flashinfer-ai/flashinfer", - project_urls={ - "Homepage": "https://github.com/flashinfer-ai/flashinfer", - "Documentation": "https://github.com/flashinfer-ai/flashinfer", - "Repository": "https://github.com/flashinfer-ai/flashinfer", - "Issue Tracker": "https://github.com/flashinfer-ai/flashinfer/issues", - }, - packages=find_packages(), - package_data={ - "flashinfer_jit_cache": ["jit_cache/**/*.so"], - }, - include_package_data=True, - python_requires=">=3.9", - classifiers=[ - "Development Status :: 4 - Beta", - "Intended Audience :: Developers", - "Operating System :: POSIX :: Linux", - "Programming Language :: Python :: 3", - "Programming Language :: Python :: 3.9", - "Programming Language :: Python :: 3.10", - "Programming Language :: Python :: 3.11", - "Programming Language :: Python :: 3.12", - "Programming Language :: Python :: 3.13", - "Topic :: Software Development :: Libraries :: Python Modules", - "Topic :: Scientific/Engineering :: Artificial Intelligence", - ], - license="Apache-2.0", - cmdclass={"bdist_wheel": PlatformSpecificBdistWheel}, - ) diff --git a/flashinfer/__init__.py b/flashinfer/__init__.py index ccfae46ee7..866a91351c 100644 --- a/flashinfer/__init__.py +++ b/flashinfer/__init__.py @@ -16,10 +16,8 @@ import importlib.util -try: - from ._build_meta import __version__ as __version__ -except ModuleNotFoundError: - __version__ = "0.0.0+unknown" +from .version import __version__ as __version__ +from .version import __git_version__ as __git_version__ from . import jit as jit diff --git a/flashinfer/__main__.py b/flashinfer/__main__.py index 701e1638ec..df3a3b5c4f 100644 --- a/flashinfer/__main__.py +++ b/flashinfer/__main__.py @@ -29,7 +29,9 @@ from .jit.env import FLASHINFER_CACHE_DIR, FLASHINFER_CUBIN_DIR from .jit.core import current_compilation_context from .jit.cpp_ext import get_cuda_path, get_cuda_version -from . import __version__ + +# Import __version__ from centralized version module +from .version import __version__ def _download_cubin(): diff --git a/flashinfer/artifacts.py b/flashinfer/artifacts.py index 4d344d02bb..fe4aeab60b 100644 --- a/flashinfer/artifacts.py +++ b/flashinfer/artifacts.py @@ -15,6 +15,7 @@ """ from dataclasses import dataclass +import logging import os import re import time @@ -23,10 +24,15 @@ import requests # type: ignore[import-untyped] import shutil -from .jit.core import logger +# Create logger for artifacts module to avoid circular import with jit.core +logger = logging.getLogger("flashinfer.artifacts") +logger.setLevel(os.getenv("FLASHINFER_LOGGING_LEVEL", "INFO").upper()) +if not logger.handlers: + logger.addHandler(logging.StreamHandler()) + from .jit.cubin_loader import ( FLASHINFER_CUBINS_REPOSITORY, - get_cubin, + download_file, safe_urljoin, FLASHINFER_CUBIN_DIR, ) @@ -125,26 +131,31 @@ def download_artifacts() -> None: # HTTPS connections. session = requests.Session() - with temp_env_var("FLASHINFER_CUBIN_CHECKSUM_DISABLED", "1"): - cubin_files = list(get_cubin_file_list()) - num_threads = int(os.environ.get("FLASHINFER_CUBIN_DOWNLOAD_THREADS", "4")) - with tqdm_logging_redirect( - total=len(cubin_files), desc="Downloading cubins" - ) as pbar: - - def update_pbar_cb(_) -> None: - pbar.update(1) - - with ThreadPoolExecutor(num_threads) as pool: - futures = [] - for name in cubin_files: - fut = pool.submit(get_cubin, name, "", session) - fut.add_done_callback(update_pbar_cb) - futures.append(fut) - - results = [fut.result() for fut in as_completed(futures)] - - all_success = all(results) + cubin_files = list(get_cubin_file_list()) + num_threads = int(os.environ.get("FLASHINFER_CUBIN_DOWNLOAD_THREADS", "4")) + with tqdm_logging_redirect( + total=len(cubin_files), desc="Downloading cubins" + ) as pbar: + + def update_pbar_cb(_) -> None: + pbar.update(1) + + with ThreadPoolExecutor(num_threads) as pool: + futures = [] + for name in cubin_files: + source = safe_urljoin(FLASHINFER_CUBINS_REPOSITORY, name) + local_path = FLASHINFER_CUBIN_DIR / name + # Ensure parent directory exists + local_path.parent.mkdir(parents=True, exist_ok=True) + fut = pool.submit( + download_file, source, str(local_path), session=session + ) + fut.add_done_callback(update_pbar_cb) + futures.append(fut) + + results = [fut.result() for fut in as_completed(futures)] + + all_success = all(results) if not all_success: raise RuntimeError("Failed to download cubins") diff --git a/flashinfer/jit/__init__.py b/flashinfer/jit/__init__.py index 4b7bc8f856..314dee1eb3 100644 --- a/flashinfer/jit/__init__.py +++ b/flashinfer/jit/__init__.py @@ -61,6 +61,7 @@ from .core import build_jit_specs as build_jit_specs from .core import clear_cache_dir as clear_cache_dir from .core import gen_jit_spec as gen_jit_spec +from .core import MissingJITCacheError as MissingJITCacheError from .core import sm90a_nvcc_flags as sm90a_nvcc_flags from .core import sm100a_nvcc_flags as sm100a_nvcc_flags from .core import sm100f_nvcc_flags as sm100f_nvcc_flags diff --git a/flashinfer/jit/attention/modules.py b/flashinfer/jit/attention/modules.py index c4a3ac9754..3fb4a289d3 100644 --- a/flashinfer/jit/attention/modules.py +++ b/flashinfer/jit/attention/modules.py @@ -20,7 +20,6 @@ import jinja2 import torch -from ...artifacts import ArtifactPath, MetaInfoHash from .. import env as jit_env from ..core import ( JitSpec, @@ -1569,6 +1568,8 @@ def gen_fmha_cutlass_sm100a_module( def gen_trtllm_gen_fmha_module(): + from ...artifacts import ArtifactPath, MetaInfoHash + include_path = f"{ArtifactPath.TRTLLM_GEN_FMHA}/include" header_name = "flashInferMetaInfo" @@ -1687,6 +1688,8 @@ def gen_customize_batch_attention_module( def gen_cudnn_fmha_module(): + from ...artifacts import ArtifactPath + return gen_jit_spec( "fmha_cudnn_gen", [jit_env.FLASHINFER_CSRC_DIR / "cudnn_sdpa_kernel_launcher.cu"], diff --git a/flashinfer/jit/core.py b/flashinfer/jit/core.py index e01e5de460..1f987b6036 100644 --- a/flashinfer/jit/core.py +++ b/flashinfer/jit/core.py @@ -18,6 +18,24 @@ os.makedirs(jit_env.FLASHINFER_CSRC_DIR, exist_ok=True) +class MissingJITCacheError(RuntimeError): + """ + Exception raised when JIT compilation is disabled and the JIT cache + does not contain the required precompiled module. + + This error indicates that a module needs to be added to the JIT cache + build configuration. + + Attributes: + spec: JitSpec of the missing module + message: Error message + """ + + def __init__(self, message: str, spec: Optional["JitSpec"] = None): + self.spec = spec + super().__init__(message) + + class FlashInferJITLogger(logging.Logger): def __init__(self, name): super().__init__(name) @@ -228,6 +246,13 @@ def is_ninja_generated(self) -> bool: return self.ninja_path.exists() def build(self, verbose: bool, need_lock: bool = True) -> None: + if os.environ.get("FLASHINFER_DISABLE_JIT"): + raise MissingJITCacheError( + "JIT compilation is disabled via FLASHINFER_DISABLE_JIT environment variable, " + "but the required module is not found in the JIT cache. " + "Please add the missing module to the JIT cache build configuration.", + spec=self, + ) lock = ( FileLock(self.lock_path, thread_local=False) if need_lock else nullcontext() ) @@ -294,7 +319,7 @@ def gen_jit_spec( cflags += ["-O3"] # useful for ncu - if bool(os.environ.get("FLASHINFER_JIT_LINEINFO", "0")): + if os.environ.get("FLASHINFER_JIT_LINEINFO", "0") == "1": cuda_cflags += ["-lineinfo"] if extra_cflags is not None: diff --git a/flashinfer/jit/cubin_loader.py b/flashinfer/jit/cubin_loader.py index 7d20281225..78615e86f9 100644 --- a/flashinfer/jit/cubin_loader.py +++ b/flashinfer/jit/cubin_loader.py @@ -44,7 +44,7 @@ def safe_urljoin(base, path): def download_file( source: str, local_path: str, - retries: int = 3, + retries: int = 4, delay: int = 5, timeout: int = 10, lock_timeout: int = 30, @@ -57,7 +57,7 @@ def download_file( - source (str): The URL or local file path of the file to download. - local_path (str): The local file path to save the downloaded/copied file. - retries (int): Number of retry attempts for URL downloads (default: 3). - - delay (int): Delay in seconds between retries (default: 5). + - delay (int): Initial delay in seconds for exponential backoff (default: 5). - timeout (int): Timeout for the HTTP request in seconds (default: 10). - lock_timeout (int): Timeout in seconds for the file lock (default: 30). @@ -87,7 +87,7 @@ def download_file( logger.error(f"Failed to copy local file: {e}") return False - # Handle URL downloads + # Handle URL downloads with exponential backoff for attempt in range(1, retries + 1): try: response = session.get(source, timeout=timeout) @@ -107,8 +107,9 @@ def download_file( ) if attempt < retries: - logger.info(f"Retrying in {delay} seconds...") - time.sleep(delay) + backoff_delay = delay * (2 ** (attempt - 1)) + logger.info(f"Retrying in {backoff_delay} seconds...") + time.sleep(backoff_delay) else: logger.error("Max retries reached. Download failed.") return False diff --git a/flashinfer/jit/env.py b/flashinfer/jit/env.py index 057ac97978..4f50552d71 100644 --- a/flashinfer/jit/env.py +++ b/flashinfer/jit/env.py @@ -20,9 +20,33 @@ import os import pathlib -import importlib.util from ..compilation_context import CompilationContext -from .. import __version__ as flashinfer_version +from ..version import __version__ as flashinfer_version + + +def has_flashinfer_jit_cache() -> bool: + """ + Check if flashinfer_jit_cache module is available. + + Returns: + True if flashinfer_jit_cache exists, False otherwise + """ + import importlib.util + + return importlib.util.find_spec("flashinfer_jit_cache") is not None + + +def has_flashinfer_cubin() -> bool: + """ + Check if flashinfer_cubin module is available. + + Returns: + True if flashinfer_cubin exists, False otherwise + """ + import importlib.util + + return importlib.util.find_spec("flashinfer_cubin") is not None + FLASHINFER_BASE_DIR: pathlib.Path = pathlib.Path( os.getenv("FLASHINFER_WORKSPACE_BASE", pathlib.Path.home().as_posix()) @@ -40,15 +64,20 @@ def _get_cubin_dir(): 3. Default cache directory """ # First check if flashinfer-cubin package is installed - if importlib.util.find_spec("flashinfer_cubin"): + if has_flashinfer_cubin(): import flashinfer_cubin flashinfer_cubin_version = flashinfer_cubin.__version__ - if flashinfer_version != flashinfer_cubin_version: + # Allow bypassing version check with environment variable + if ( + not os.getenv("FLASHINFER_DISABLE_VERSION_CHECK") + and flashinfer_version != flashinfer_cubin_version + ): raise RuntimeError( f"flashinfer-cubin version ({flashinfer_cubin_version}) does not match " f"flashinfer version ({flashinfer_version}). " - "Please install the same version of both packages." + "Please install the same version of both packages. " + "Set FLASHINFER_DISABLE_VERSION_CHECK=1 to bypass this check." ) return pathlib.Path(flashinfer_cubin.get_cubin_dir()) @@ -72,17 +101,21 @@ def _get_aot_dir(): 2. Default fallback to _package_root / "data" / "aot" """ # First check if flashinfer-jit-cache package is installed - if importlib.util.find_spec("flashinfer_jit_cache"): + if has_flashinfer_jit_cache(): import flashinfer_jit_cache flashinfer_jit_cache_version = flashinfer_jit_cache.__version__ # NOTE(Zihao): we don't use exact version match here because the version of flashinfer-jit-cache # contains the CUDA version suffix: e.g. 0.3.1+cu129. - if not flashinfer_jit_cache_version.startswith(flashinfer_version): + # Allow bypassing version check with environment variable + if not os.getenv( + "FLASHINFER_DISABLE_VERSION_CHECK" + ) and not flashinfer_jit_cache_version.startswith(flashinfer_version): raise RuntimeError( f"flashinfer-jit-cache version ({flashinfer_jit_cache_version}) does not match " f"flashinfer version ({flashinfer_version}). " - "Please install the same version of both packages." + "Please install the same version of both packages. " + "Set FLASHINFER_DISABLE_VERSION_CHECK=1 to bypass this check." ) return pathlib.Path(flashinfer_jit_cache.get_jit_cache_dir()) diff --git a/flashinfer/utils.py b/flashinfer/utils.py index 448f6d116a..d107c88298 100644 --- a/flashinfer/utils.py +++ b/flashinfer/utils.py @@ -450,6 +450,13 @@ def has_cuda_cudart() -> bool: return importlib.util.find_spec("cuda.cudart") is not None +# Re-export from jit.env to avoid circular dependency +from .jit.env import ( + has_flashinfer_jit_cache as has_flashinfer_jit_cache, + has_flashinfer_cubin as has_flashinfer_cubin, +) + + def get_cuda_python_version() -> str: import cuda diff --git a/flashinfer/version.py b/flashinfer/version.py new file mode 100644 index 0000000000..95ad245497 --- /dev/null +++ b/flashinfer/version.py @@ -0,0 +1,23 @@ +""" +Copyright (c) 2023 by FlashInfer team. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +# Centralized version information to avoid circular imports +try: + from ._build_meta import __version__ as __version__ + from ._build_meta import __git_version__ as __git_version__ +except ModuleNotFoundError: + __version__ = "0.0.0+unknown" + __git_version__ = "unknown" diff --git a/pyproject.toml b/pyproject.toml index ef13baa13f..679a8b905d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -23,9 +23,12 @@ urls = { Homepage = "https://github.com/flashinfer-ai/flashinfer" } dynamic = ["dependencies", "version"] license-files = ["LICENSE", "licenses/*"] +[project.scripts] +flashinfer = "flashinfer.__main__:cli" + [build-system] -requires = ["setuptools>=77", "packaging>=24"] -build-backend = "custom_backend" +requires = ["setuptools>=77", "packaging>=24", "apache-tvm-ffi==0.1.0b15"] +build-backend = "build_backend" backend-path = ["."] [tool.codespell] @@ -39,10 +42,16 @@ skip = [ [tool.setuptools] include-package-data = false +py-modules = ["build_backend", "build_utils"] + +[tool.setuptools.dynamic] +version = {attr = "flashinfer._build_meta.__version__"} +dependencies = {file = ["requirements.txt"]} [tool.setuptools.packages.find] where = ["."] include = ["flashinfer*"] +exclude = ["flashinfer-jit-cache*", "flashinfer-cubin*"] [tool.setuptools.package-dir] "flashinfer.data" = "." @@ -50,10 +59,12 @@ include = ["flashinfer*"] "flashinfer.data.spdlog" = "3rdparty/spdlog" [tool.setuptools.package-data] +"flashinfer" = [ + "_build_meta.py" +] "flashinfer.data" = [ "csrc/**", - "include/**", - "version.txt" + "include/**" ] "flashinfer.data.cutlass" = [ "include/**", diff --git a/pytest.ini b/pytest.ini index 765de5e01a..91154582ca 100644 --- a/pytest.ini +++ b/pytest.ini @@ -1,2 +1,3 @@ [pytest] norecursedirs = test_helpers +addopts = --import-mode=importlib diff --git a/scripts/build_flashinfer_jit_cache_whl.sh b/scripts/build_flashinfer_jit_cache_whl.sh index 30495ff377..fd1a313e14 100755 --- a/scripts/build_flashinfer_jit_cache_whl.sh +++ b/scripts/build_flashinfer_jit_cache_whl.sh @@ -27,8 +27,10 @@ echo "CUDA Major: ${CUDA_MAJOR}" echo "CUDA Minor: ${CUDA_MINOR}" echo "CUDA Version Suffix: ${CUDA_VERSION_SUFFIX}" echo "CUDA Architectures: ${FLASHINFER_CUDA_ARCH_LIST}" +echo "Dev Release Suffix: ${FLASHINFER_DEV_RELEASE_SUFFIX}" echo "MAX_JOBS: ${MAX_JOBS}" echo "Python Version: $(python3 --version)" +echo "Git commit: $(git rev-parse HEAD 2>/dev/null || echo 'unknown')" echo "Working directory: $(pwd)" echo "" @@ -62,6 +64,16 @@ echo "" echo "Built wheels:" ls -lh dist/ +# Verify version and git version +echo "" +echo "Verifying version and git version..." +pip install dist/*.whl +python -c " +import flashinfer_jit_cache +print(f'šŸ“¦ Package version: {flashinfer_jit_cache.__version__}') +print(f'šŸ”– Git version: {flashinfer_jit_cache.__git_version__}') +" + # Copy wheels to output directory if specified if [ -n "${OUTPUT_DIR}" ]; then echo "" diff --git a/scripts/print_jit_cache_summary.py b/scripts/print_jit_cache_summary.py new file mode 100644 index 0000000000..c719374d08 --- /dev/null +++ b/scripts/print_jit_cache_summary.py @@ -0,0 +1,86 @@ +#!/usr/bin/env python3 +""" +Print aggregated JIT cache coverage summary from multiple pytest runs. + +This script reads the JSON file written by pytest (via conftest.py) and +prints a unified summary of all missing JIT cache modules across all test runs. + +Usage: + python scripts/print_jit_cache_summary.py /path/to/aggregate.json +""" + +import json +import sys +from pathlib import Path +from typing import Dict, Set + + +def print_jit_cache_summary(aggregate_file: str): + """Read and print JIT cache coverage summary from aggregate file""" + aggregate_path = Path(aggregate_file) + + if not aggregate_path.exists(): + print("No JIT cache coverage data found.") + print(f"Expected file: {aggregate_file}") + return + + # Read all entries from the file + missing_modules: Set[tuple] = set() + with open(aggregate_file, "r") as f: + for line in f: + line = line.strip() + if not line: + continue + entry = json.loads(line) + missing_modules.add( + (entry["test_name"], entry["module_name"], entry["spec_info"]) + ) + + if not missing_modules: + print("āœ… All tests passed - no missing JIT cache modules!") + return + + # Print summary + print("=" * 80) + print("flashinfer-jit-cache Package Coverage Report") + print("=" * 80) + print() + print("This report shows the coverage of the flashinfer-jit-cache package.") + print( + "Tests are skipped when required modules are not found in the installed JIT cache." + ) + print() + print( + f"āš ļø {len(missing_modules)} test(s) skipped due to missing JIT cache modules:" + ) + print() + + # Group by module name + module_to_tests: Dict[str, Dict] = {} + for test_name, module_name, spec_info in missing_modules: + if module_name not in module_to_tests: + module_to_tests[module_name] = {"tests": [], "spec_info": spec_info} + module_to_tests[module_name]["tests"].append(test_name) + + for module_name in sorted(module_to_tests.keys()): + info = module_to_tests[module_name] + print(f"Module: {module_name}") + print(f" Spec: {info['spec_info']}") + print(f" Affected tests ({len(info['tests'])}):") + for test in sorted(info["tests"]): + print(f" - {test}") + print() + + print("These tests require JIT compilation but FLASHINFER_DISABLE_JIT=1 was set.") + print( + "To improve coverage, add the missing modules to the flashinfer-jit-cache build configuration." + ) + print("=" * 80) + + +if __name__ == "__main__": + if len(sys.argv) != 2: + print("Usage: python scripts/print_jit_cache_summary.py ") + sys.exit(1) + + print_jit_cache_summary(sys.argv[1]) diff --git a/scripts/task_jit_run_tests_part1.sh b/scripts/task_jit_run_tests_part1.sh index 38f89999a3..60deeaac7b 100755 --- a/scripts/task_jit_run_tests_part1.sh +++ b/scripts/task_jit_run_tests_part1.sh @@ -4,12 +4,16 @@ set -eo pipefail set -x : ${MAX_JOBS:=$(nproc)} : ${CUDA_VISIBLE_DEVICES:=0} +: ${SKIP_INSTALL:=0} -pip install -e . -v +if [ "$SKIP_INSTALL" = "0" ]; then + pip install -e . -v +fi -# pytest -s tests/gemm/test_group_gemm.py +# Run each test file separately to isolate CUDA memory issues pytest -s tests/attention/test_logits_cap.py pytest -s tests/attention/test_sliding_window.py pytest -s tests/attention/test_tensor_cores_decode.py pytest -s tests/attention/test_batch_decode_kernels.py +# pytest -s tests/gemm/test_group_gemm.py # pytest -s tests/attention/test_alibi.py diff --git a/scripts/task_jit_run_tests_part2.sh b/scripts/task_jit_run_tests_part2.sh index 7fc15fef25..b4bb6bf17c 100755 --- a/scripts/task_jit_run_tests_part2.sh +++ b/scripts/task_jit_run_tests_part2.sh @@ -4,9 +4,13 @@ set -eo pipefail set -x : ${MAX_JOBS:=$(nproc)} : ${CUDA_VISIBLE_DEVICES:=0} +: ${SKIP_INSTALL:=0} -pip install -e . -v +if [ "$SKIP_INSTALL" = "0" ]; then + pip install -e . -v +fi +# Run each test file separately to isolate CUDA memory issues pytest -s tests/utils/test_block_sparse.py pytest -s tests/utils/test_jit_example.py pytest -s tests/utils/test_jit_warmup.py diff --git a/scripts/task_jit_run_tests_part3.sh b/scripts/task_jit_run_tests_part3.sh index da342eec19..cb59c7e84f 100755 --- a/scripts/task_jit_run_tests_part3.sh +++ b/scripts/task_jit_run_tests_part3.sh @@ -4,7 +4,11 @@ set -eo pipefail set -x : ${MAX_JOBS:=$(nproc)} : ${CUDA_VISIBLE_DEVICES:=0} +: ${SKIP_INSTALL:=0} -pip install -e . -v +if [ "$SKIP_INSTALL" = "0" ]; then + pip install -e . -v +fi +# Run each test file separately to isolate CUDA memory issues pytest -s tests/utils/test_sampling.py diff --git a/scripts/task_jit_run_tests_part4.sh b/scripts/task_jit_run_tests_part4.sh index cea153b6b9..c771fa37a7 100755 --- a/scripts/task_jit_run_tests_part4.sh +++ b/scripts/task_jit_run_tests_part4.sh @@ -4,10 +4,15 @@ set -eo pipefail set -x : ${MAX_JOBS:=$(nproc)} : ${CUDA_VISIBLE_DEVICES:=0} +: ${SKIP_INSTALL:=0} -pip install -e . -v +if [ "$SKIP_INSTALL" = "0" ]; then + pip install -e . -v +fi export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True # avoid memory fragmentation + +# Run each test file separately to isolate CUDA memory issues pytest -s tests/attention/test_deepseek_mla.py pytest -s tests/gemm/test_group_gemm.py pytest -s tests/attention/test_batch_prefill_kernels.py diff --git a/scripts/task_jit_run_tests_part5.sh b/scripts/task_jit_run_tests_part5.sh index 58bb21babd..a4ada8334a 100755 --- a/scripts/task_jit_run_tests_part5.sh +++ b/scripts/task_jit_run_tests_part5.sh @@ -4,7 +4,11 @@ set -eo pipefail set -x : ${MAX_JOBS:=$(nproc)} : ${CUDA_VISIBLE_DEVICES:=0} +: ${SKIP_INSTALL:=0} -pip install -e . -v +if [ "$SKIP_INSTALL" = "0" ]; then + pip install -e . -v +fi +# Run each test file separately to isolate CUDA memory issues pytest -s tests/utils/test_logits_processor.py diff --git a/scripts/task_test_nightly_build.sh b/scripts/task_test_nightly_build.sh new file mode 100755 index 0000000000..46f6b76d36 --- /dev/null +++ b/scripts/task_test_nightly_build.sh @@ -0,0 +1,48 @@ +#!/bin/bash + +set -eo pipefail +set -x + +# This script installs nightly build packages and runs tests +# Expected dist directories to be in current directory or specified via env vars + +: ${TEST_SHARD:=1} +: ${CUDA_VISIBLE_DEVICES:=0} +: ${DIST_CUBIN_DIR:=dist-cubin} +: ${DIST_JIT_CACHE_DIR:=dist-jit-cache} +: ${DIST_PYTHON_DIR:=dist-python} + +# Display GPU information (running inside Docker container with GPU access) +echo "=== GPU Information ===" +nvidia-smi + +# Install flashinfer packages +echo "Installing flashinfer-cubin from ${DIST_CUBIN_DIR}..." +pip install ${DIST_CUBIN_DIR}/*.whl + +echo "Installing flashinfer-jit-cache from ${DIST_JIT_CACHE_DIR}..." +pip install ${DIST_JIT_CACHE_DIR}/*.whl + +# Disable JIT to verify that jit-cache package contains all necessary +# precompiled modules for the test suite to pass without compilation +echo "Disabling JIT compilation to test with precompiled cache only..." +export FLASHINFER_DISABLE_JIT=1 + +echo "Installing flashinfer-python from ${DIST_PYTHON_DIR}..." +pip install ${DIST_PYTHON_DIR}/*.tar.gz + +# Verify installation +echo "Verifying installation..." +# Run from /tmp to avoid importing local flashinfer/ source directory +(cd /tmp && python -m flashinfer show-config) + +# Run test shard +echo "Running test shard ${TEST_SHARD}..." +export SKIP_INSTALL=1 + +# Pass through JIT cache report file if set +if [ -n "${FLASHINFER_JIT_CACHE_REPORT_FILE}" ]; then + export FLASHINFER_JIT_CACHE_REPORT_FILE +fi + +bash scripts/task_jit_run_tests_part${TEST_SHARD}.sh diff --git a/scripts/update_whl_index.py b/scripts/update_whl_index.py index 4902d0cc7f..474ec61ea9 100644 --- a/scripts/update_whl_index.py +++ b/scripts/update_whl_index.py @@ -1,17 +1,261 @@ +""" +Update wheel index for flashinfer packages. + +This script generates PEP 503 compatible simple repository index pages for: +- flashinfer-python (no CUDA suffix in version) +- flashinfer-cubin (no CUDA suffix in version) +- flashinfer-jit-cache (has CUDA suffix like +cu130) + +The index is organized by CUDA version for jit-cache, and flat for others. +""" + import hashlib import pathlib import re +import argparse +import sys +from typing import Optional + + +def get_cuda_version(wheel_name: str) -> Optional[str]: + """Extract CUDA version from wheel filename.""" + # Match patterns like +cu128, +cu129, +cu130 + match = re.search(r"\+cu(\d+)", wheel_name) + if match: + return match.group(1) + return None + + +def get_package_info(wheel_path: pathlib.Path) -> Optional[dict]: + """Extract package information from wheel filename.""" + wheel_name = wheel_path.name + + # Try flashinfer-python pattern + match = re.match(r"flashinfer_python-([0-9.]+(?:\.dev\d+)?)-", wheel_name) + if match: + version = match.group(1) + return { + "package": "flashinfer-python", + "version": version, + "cuda": None, + } + + # Try flashinfer-cubin pattern + match = re.match(r"flashinfer_cubin-([0-9.]+(?:\.dev\d+)?)-", wheel_name) + if match: + version = match.group(1) + return { + "package": "flashinfer-cubin", + "version": version, + "cuda": None, + } + + # Try flashinfer-jit-cache pattern (has CUDA suffix in version) + match = re.match(r"flashinfer_jit_cache-([0-9.]+(?:\.dev\d+)?\+cu\d+)-", wheel_name) + if match: + version = match.group(1) + cuda_ver = get_cuda_version(wheel_name) + return { + "package": "flashinfer-jit-cache", + "version": version, + "cuda": cuda_ver, + } + + return None + + +def compute_sha256(file_path: pathlib.Path) -> str: + """Compute SHA256 hash of a file.""" + with open(file_path, "rb") as f: + return hashlib.sha256(f.read()).hexdigest() + + +def generate_directory_index(directory: pathlib.Path): + """Generate index.html for a directory listing its subdirectories.""" + # Get all subdirectories + subdirs = sorted([d for d in directory.iterdir() if d.is_dir()]) + + if not subdirs: + return + + index_file = directory / "index.html" + + # Generate HTML for directory listing + with index_file.open("w") as f: + f.write("\n") + f.write("\n") + f.write(f"Index of {directory.name or 'root'}\n") + f.write("\n") + f.write(f"

Index of {directory.name or 'root'}

\n") + + for subdir in subdirs: + f.write(f'{subdir.name}/
\n') + + f.write("\n") + f.write("\n") + + +def update_parent_indices(leaf_dir: pathlib.Path, root_dir: pathlib.Path): + """Recursively update index.html for all parent directories.""" + current = leaf_dir.parent + + while current >= root_dir and current != current.parent: + generate_directory_index(current) + current = current.parent + + +def update_index( + dist_dir: str = "dist", + output_dir: str = "whl", + base_url: str = "https://github.com/flashinfer-ai/flashinfer/releases/download", + release_tag: Optional[str] = None, + nightly: bool = False, +): + """ + Update wheel index from dist directory. + + Args: + dist_dir: Directory containing wheel files + output_dir: Output directory for index files + base_url: Base URL for wheel downloads + release_tag: GitHub release tag (e.g., 'nightly' or 'v0.3.1') + nightly: If True, update index to whl/nightly subdirectory for nightly releases + """ + dist_path = pathlib.Path(dist_dir) + if not dist_path.exists(): + print(f"Error: dist directory '{dist_dir}' does not exist") + sys.exit(1) + + wheels = sorted(dist_path.glob("*.whl")) + if not wheels: + print(f"No wheel files found in '{dist_dir}'") + sys.exit(1) + + print(f"Found {len(wheels)} wheel file(s)") + + # Track all directories that need parent index updates + created_dirs = set() + + for wheel_path in wheels: + print(f"\nProcessing: {wheel_path.name}") + + # Extract package information + info = get_package_info(wheel_path) + if not info: + print(" āš ļø Skipping: Could not parse wheel filename") + continue + + # Compute SHA256 + sha256 = compute_sha256(wheel_path) + + # Determine index directory + package = info["package"] + cuda = info["cuda"] + + # Add nightly subdirectory if nightly flag is set + base_output = pathlib.Path(output_dir) + if nightly: + base_output = base_output / "nightly" + + if cuda: + # CUDA-specific index for jit-cache: whl/nightly/cu130/flashinfer-jit-cache/ + index_dir = base_output / f"cu{cuda}" / package + else: + # No CUDA version for python/cubin: whl/nightly/flashinfer-python/ + index_dir = base_output / package + + index_dir.mkdir(parents=True, exist_ok=True) + created_dirs.add(index_dir) + + # Construct download URL + tag = release_tag or f"v{info['version'].split('+')[0].split('.dev')[0]}" + download_url = f"{base_url}/{tag}/{wheel_path.name}#sha256={sha256}" + + # Update index.html + index_file = index_dir / "index.html" + + # Read existing links to avoid duplicates + links = set() + if index_file.exists(): + with index_file.open("r") as f: + content = f.read() + # Simple regex to extract the tags + links.update(re.findall(r'.*?
\n', content)) + + # Create and add new link + new_link = f'{wheel_path.name}
\n' + is_new = new_link not in links + if is_new: + links.add(new_link) + + # Write the complete, valid HTML file + with index_file.open("w") as f: + f.write("\n") + f.write("\n") + f.write(f"Links for {package}\n") + f.write("\n") + f.write(f"

Links for {package}

\n") + for link in sorted(list(links)): + f.write(link) + f.write("\n") + f.write("\n") + print(f" āœ… Added to index: {index_dir}/index.html") + else: + print(f" ā„¹ļø Already in index: {index_dir}/index.html") + + print(f" šŸ“¦ Package: {package}") + print(f" šŸ”– Version: {info['version']}") + if cuda: + print(f" šŸŽ® CUDA: cu{cuda}") + print(f" šŸ“ URL: {download_url}") + + # Update parent directory indices + print("\nšŸ“‚ Updating parent directory indices...") + root_output = pathlib.Path(output_dir) + for leaf_dir in created_dirs: + update_parent_indices(leaf_dir, root_output) + print(" āœ… Parent indices updated") + + +def main(): + parser = argparse.ArgumentParser( + description="Update wheel index for flashinfer packages" + ) + parser.add_argument( + "--dist-dir", + default="dist", + help="Directory containing wheel files (default: dist)", + ) + parser.add_argument( + "--output-dir", + default="whl", + help="Output directory for index files (default: whl)", + ) + parser.add_argument( + "--base-url", + default="https://github.com/flashinfer-ai/flashinfer/releases/download", + help="Base URL for wheel downloads", + ) + parser.add_argument( + "--release-tag", + help="GitHub release tag (e.g., 'nightly' or 'v0.3.1'). If not specified, will be derived from version.", + ) + parser.add_argument( + "--nightly", + action="store_true", + help="Update index to whl/nightly subdirectory for nightly releases", + ) + + args = parser.parse_args() + + update_index( + dist_dir=args.dist_dir, + output_dir=args.output_dir, + base_url=args.base_url, + release_tag=args.release_tag, + nightly=args.nightly, + ) + -for path in sorted(pathlib.Path("dist").glob("*.whl")): - with open(path, "rb") as f: - sha256 = hashlib.sha256(f.read()).hexdigest() - ver, cu, torch = re.findall( - r"flashinfer_python-([0-9.]+(?:\.post[0-9]+)?)\+cu(\d+)torch([0-9.]+)-", - path.name, - )[0] - index_dir = pathlib.Path(f"flashinfer-whl/cu{cu}/torch{torch}/flashinfer-python") - index_dir.mkdir(exist_ok=True) - base_url = "https://github.com/flashinfer-ai/flashinfer/releases/download" - full_url = f"{base_url}/v{ver}/{path.name}#sha256={sha256}" - with (index_dir / "index.html").open("a") as f: - f.write(f'{path.name}
\n') +if __name__ == "__main__": + main() diff --git a/setup.py b/setup.py deleted file mode 100755 index 3f021464e9..0000000000 --- a/setup.py +++ /dev/null @@ -1,67 +0,0 @@ -""" -Copyright (c) 2023 by FlashInfer team. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -""" - -from pathlib import Path -from typing import List, Mapping - -import setuptools - -root = Path(__file__).parent.resolve() - - -def write_if_different(path: Path, content: str) -> None: - if path.exists() and path.read_text() == content: - return - path.parent.mkdir(parents=True, exist_ok=True) - path.write_text(content) - - -def get_version(): - package_version = (root / "version.txt").read_text().strip() - return f"{package_version}" - - -def generate_build_meta() -> None: - build_meta_str = f"__version__ = {get_version()!r}\n" - write_if_different(root / "flashinfer" / "_build_meta.py", build_meta_str) - - -ext_modules: List[setuptools.Extension] = [] -cmdclass: Mapping[str, type[setuptools.Command]] = {} - - -def get_install_requires() -> List[str]: - """Read install requirements from requirements.txt.""" - requirements_file = root / "requirements.txt" - if not requirements_file.exists(): - return [] - return [ - line.strip() - for line in requirements_file.read_text().splitlines() - if line.strip() and not line.strip().startswith("#") - ] - - -install_requires = get_install_requires() -generate_build_meta() - - -setuptools.setup( - version=get_version(), - ext_modules=ext_modules, - cmdclass=cmdclass, - install_requires=install_requires, -) diff --git a/tests/attention/test_alibi.py b/tests/attention/test_alibi.py index 417be942b7..21114aea7c 100644 --- a/tests/attention/test_alibi.py +++ b/tests/attention/test_alibi.py @@ -23,9 +23,13 @@ ) import flashinfer +from flashinfer.utils import has_flashinfer_jit_cache -@pytest.fixture(autouse=True, scope="module") +@pytest.fixture( + autouse=not has_flashinfer_jit_cache(), + scope="module", +) def warmup_jit(): flashinfer.jit.build_jit_specs( gen_decode_attention_modules( diff --git a/tests/attention/test_attention_sink.py b/tests/attention/test_attention_sink.py index ab3ddc6c4b..aeacae1da1 100644 --- a/tests/attention/test_attention_sink.py +++ b/tests/attention/test_attention_sink.py @@ -24,10 +24,13 @@ from flashinfer.jit.utils import filename_safe_dtype_map from flashinfer.jit.attention import gen_batch_prefill_attention_sink_module from flashinfer.jit.attention.variants import attention_sink_decl -from flashinfer.utils import is_sm90a_supported +from flashinfer.utils import has_flashinfer_jit_cache, is_sm90a_supported -@pytest.fixture(autouse=True, scope="module") +@pytest.fixture( + autouse=not has_flashinfer_jit_cache(), + scope="module", +) def warmup_jit(): jit_specs = [] for dtype in [torch.float16, torch.bfloat16]: diff --git a/tests/attention/test_batch_attention.py b/tests/attention/test_batch_attention.py index 6df9bdc2a2..1a0532b479 100644 --- a/tests/attention/test_batch_attention.py +++ b/tests/attention/test_batch_attention.py @@ -23,10 +23,13 @@ gen_persistent_batch_attention_modules, gen_prefill_attention_modules, ) -from flashinfer.utils import get_compute_capability +from flashinfer.utils import get_compute_capability, has_flashinfer_jit_cache -@pytest.fixture(autouse=True, scope="module") +@pytest.fixture( + autouse=not has_flashinfer_jit_cache(), + scope="module", +) def warmup_jit(): flashinfer.jit.build_jit_specs( gen_persistent_batch_attention_modules( diff --git a/tests/attention/test_batch_decode_kernels.py b/tests/attention/test_batch_decode_kernels.py index cd04c273c3..39e736306a 100644 --- a/tests/attention/test_batch_decode_kernels.py +++ b/tests/attention/test_batch_decode_kernels.py @@ -22,9 +22,13 @@ ) from functools import partial import flashinfer +from flashinfer.utils import has_flashinfer_jit_cache -@pytest.fixture(autouse=True, scope="module") +@pytest.fixture( + autouse=not has_flashinfer_jit_cache(), + scope="module", +) def warmup_jit(): flashinfer.jit.build_jit_specs( gen_decode_attention_modules( diff --git a/tests/attention/test_batch_invariant_fa2.py b/tests/attention/test_batch_invariant_fa2.py index 39e7102349..ea7abeb2c7 100644 --- a/tests/attention/test_batch_invariant_fa2.py +++ b/tests/attention/test_batch_invariant_fa2.py @@ -22,9 +22,13 @@ ) import flashinfer +from flashinfer.utils import has_flashinfer_jit_cache -@pytest.fixture(autouse=True, scope="module") +@pytest.fixture( + autouse=not has_flashinfer_jit_cache(), + scope="module", +) def warmup_jit(): flashinfer.jit.build_jit_specs( gen_decode_attention_modules( diff --git a/tests/attention/test_batch_prefill_kernels.py b/tests/attention/test_batch_prefill_kernels.py index 8c89ee94d0..f067a70c62 100644 --- a/tests/attention/test_batch_prefill_kernels.py +++ b/tests/attention/test_batch_prefill_kernels.py @@ -20,9 +20,13 @@ from tests.test_helpers.jit_utils import gen_prefill_attention_modules import flashinfer +from flashinfer.utils import has_flashinfer_jit_cache -@pytest.fixture(autouse=True, scope="module") +@pytest.fixture( + autouse=not has_flashinfer_jit_cache(), + scope="module", +) def warmup_jit(): flashinfer.jit.build_jit_specs( gen_prefill_attention_modules( diff --git a/tests/attention/test_deepseek_mla.py b/tests/attention/test_deepseek_mla.py index 85cafc2d86..0976c4ff39 100644 --- a/tests/attention/test_deepseek_mla.py +++ b/tests/attention/test_deepseek_mla.py @@ -28,13 +28,17 @@ gen_single_prefill_module, ) from flashinfer.utils import ( + has_flashinfer_jit_cache, is_sm90a_supported, is_sm100a_supported, is_sm110a_supported, ) -@pytest.fixture(autouse=True, scope="module") +@pytest.fixture( + autouse=not has_flashinfer_jit_cache(), + scope="module", +) def warmup_jit(): try: modules = [] diff --git a/tests/attention/test_logits_cap.py b/tests/attention/test_logits_cap.py index 5059f3764e..14791cac09 100644 --- a/tests/attention/test_logits_cap.py +++ b/tests/attention/test_logits_cap.py @@ -24,9 +24,13 @@ ) import flashinfer +from flashinfer.utils import has_flashinfer_jit_cache -@pytest.fixture(autouse=True, scope="module") +@pytest.fixture( + autouse=not has_flashinfer_jit_cache(), + scope="module", +) def warmup_jit(): flashinfer.jit.build_jit_specs( gen_decode_attention_modules( diff --git a/tests/attention/test_non_contiguous_decode.py b/tests/attention/test_non_contiguous_decode.py index 198ecd3e9f..c27ac11e5d 100644 --- a/tests/attention/test_non_contiguous_decode.py +++ b/tests/attention/test_non_contiguous_decode.py @@ -6,9 +6,13 @@ ) import flashinfer +from flashinfer.utils import has_flashinfer_jit_cache -@pytest.fixture(autouse=True, scope="module") +@pytest.fixture( + autouse=not has_flashinfer_jit_cache(), + scope="module", +) def warmup_jit(): flashinfer.jit.build_jit_specs( gen_decode_attention_modules( diff --git a/tests/attention/test_non_contiguous_prefill.py b/tests/attention/test_non_contiguous_prefill.py index 627ef3ca63..96ad4aef05 100644 --- a/tests/attention/test_non_contiguous_prefill.py +++ b/tests/attention/test_non_contiguous_prefill.py @@ -19,9 +19,13 @@ from tests.test_helpers.jit_utils import gen_prefill_attention_modules import flashinfer +from flashinfer.utils import has_flashinfer_jit_cache -@pytest.fixture(autouse=True, scope="module") +@pytest.fixture( + autouse=not has_flashinfer_jit_cache(), + scope="module", +) def warmup_jit(): flashinfer.jit.build_jit_specs( gen_prefill_attention_modules( diff --git a/tests/attention/test_shared_prefix_kernels.py b/tests/attention/test_shared_prefix_kernels.py index fc25b8afc5..30aee0dc38 100644 --- a/tests/attention/test_shared_prefix_kernels.py +++ b/tests/attention/test_shared_prefix_kernels.py @@ -22,9 +22,13 @@ ) import flashinfer +from flashinfer.utils import has_flashinfer_jit_cache -@pytest.fixture(autouse=True, scope="module") +@pytest.fixture( + autouse=not has_flashinfer_jit_cache(), + scope="module", +) def warmup_jit(): flashinfer.jit.build_jit_specs( gen_decode_attention_modules( diff --git a/tests/attention/test_sliding_window.py b/tests/attention/test_sliding_window.py index fa22610578..e29c984d66 100644 --- a/tests/attention/test_sliding_window.py +++ b/tests/attention/test_sliding_window.py @@ -22,9 +22,13 @@ ) import flashinfer +from flashinfer.utils import has_flashinfer_jit_cache -@pytest.fixture(autouse=True, scope="module") +@pytest.fixture( + autouse=not has_flashinfer_jit_cache(), + scope="module", +) def warmup_jit(): flashinfer.jit.build_jit_specs( gen_decode_attention_modules( diff --git a/tests/attention/test_tensor_cores_decode.py b/tests/attention/test_tensor_cores_decode.py index c5bbd84d81..19db15a640 100644 --- a/tests/attention/test_tensor_cores_decode.py +++ b/tests/attention/test_tensor_cores_decode.py @@ -22,9 +22,13 @@ ) from functools import partial import flashinfer +from flashinfer.utils import has_flashinfer_jit_cache -@pytest.fixture(autouse=True, scope="module") +@pytest.fixture( + autouse=not has_flashinfer_jit_cache(), + scope="module", +) def warmup_jit(): flashinfer.jit.build_jit_specs( gen_decode_attention_modules( diff --git a/tests/conftest.py b/tests/conftest.py index 3e6694d1d8..dc81dc0db2 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,6 +1,8 @@ +import json import os import types -from typing import Any, Dict +from pathlib import Path +from typing import Any, Dict, Set import pytest import torch @@ -8,6 +10,14 @@ from torch.torch_version import __version__ as torch_version import flashinfer +from flashinfer.jit import MissingJITCacheError + +# Global tracking for JIT cache coverage +# Store tuples of (test_name, module_name, spec_info) +_MISSING_JIT_CACHE_MODULES: Set[tuple] = set() + +# File path for aggregating JIT cache info across multiple pytest runs +_JIT_CACHE_REPORT_FILE = os.environ.get("FLASHINFER_JIT_CACHE_REPORT_FILE", None) TORCH_COMPILE_FNS = [ flashinfer.activation.silu_and_mul, @@ -129,11 +139,92 @@ def is_cuda_oom_error_str(e: str) -> bool: @pytest.hookimpl(tryfirst=True) def pytest_runtest_call(item): - # skip OOM error + # skip OOM error and missing JIT cache errors try: item.runtest() except (torch.cuda.OutOfMemoryError, RuntimeError) as e: if isinstance(e, torch.cuda.OutOfMemoryError) or is_cuda_oom_error_str(str(e)): pytest.skip("Skipping due to OOM") + elif isinstance(e, MissingJITCacheError): + # Record the test that was skipped due to missing JIT cache + test_name = item.nodeid + spec = e.spec + module_name = spec.name if spec else "unknown" + + # Create a dict with module info for reporting + spec_info = None + if spec: + spec_info = { + "name": spec.name, + "sources": [str(s) for s in spec.sources], + "needs_device_linking": spec.needs_device_linking, + "aot_path": str(spec.aot_path), + } + + _MISSING_JIT_CACHE_MODULES.add((test_name, module_name, str(spec_info))) + pytest.skip(f"Skipping due to missing JIT cache for module: {module_name}") else: raise + + +def pytest_terminal_summary(terminalreporter, exitstatus, config): + """Generate JIT cache coverage report at the end of test session""" + if not _MISSING_JIT_CACHE_MODULES: + return # No missing modules + + # If report file is specified, write to file for later aggregation + # Otherwise, print summary directly + if _JIT_CACHE_REPORT_FILE: + from filelock import FileLock + + # Convert set to list for JSON serialization + data = [ + {"test_name": test_name, "module_name": module_name, "spec_info": spec_info} + for test_name, module_name, spec_info in _MISSING_JIT_CACHE_MODULES + ] + + # Use file locking to handle concurrent writes from multiple pytest processes + Path(_JIT_CACHE_REPORT_FILE).parent.mkdir(parents=True, exist_ok=True) + lock_file = _JIT_CACHE_REPORT_FILE + ".lock" + with FileLock(lock_file), open(_JIT_CACHE_REPORT_FILE, "a") as f: + for entry in data: + f.write(json.dumps(entry) + "\n") + return + + # Single pytest run - print summary directly + terminalreporter.section("flashinfer-jit-cache Package Coverage Report") + terminalreporter.write_line("") + terminalreporter.write_line( + "This report shows the coverage of the flashinfer-jit-cache package." + ) + terminalreporter.write_line( + "Tests are skipped when required modules are not found in the installed JIT cache." + ) + terminalreporter.write_line("") + terminalreporter.write_line( + f"āš ļø {len(_MISSING_JIT_CACHE_MODULES)} test(s) skipped due to missing JIT cache modules:" + ) + terminalreporter.write_line("") + + # Group by module name + module_to_tests = {} + for test_name, module_name, spec_info in _MISSING_JIT_CACHE_MODULES: + if module_name not in module_to_tests: + module_to_tests[module_name] = {"tests": [], "spec_info": spec_info} + module_to_tests[module_name]["tests"].append(test_name) + + for module_name in sorted(module_to_tests.keys()): + info = module_to_tests[module_name] + terminalreporter.write_line(f"Module: {module_name}") + terminalreporter.write_line(f" Spec: {info['spec_info']}") + terminalreporter.write_line(f" Affected tests ({len(info['tests'])}):") + for test in sorted(info["tests"]): + terminalreporter.write_line(f" - {test}") + terminalreporter.write_line("") + + terminalreporter.write_line( + "These tests require JIT compilation but FLASHINFER_DISABLE_JIT=1 was set." + ) + terminalreporter.write_line( + "To improve coverage, add the missing modules to the flashinfer-jit-cache build configuration." + ) diff --git a/tests/gemm/test_group_gemm.py b/tests/gemm/test_group_gemm.py index 7e0a02c610..fbdd9e26e4 100644 --- a/tests/gemm/test_group_gemm.py +++ b/tests/gemm/test_group_gemm.py @@ -18,13 +18,20 @@ import torch import flashinfer -from flashinfer.utils import determine_gemm_backend, is_sm90a_supported +from flashinfer.utils import ( + determine_gemm_backend, + has_flashinfer_jit_cache, + is_sm90a_supported, +) DTYPES = [torch.float16] CUDA_DEVICES = ["cuda:0"] -@pytest.fixture(autouse=True, scope="module") +@pytest.fixture( + autouse=not has_flashinfer_jit_cache(), + scope="module", +) def warmup_jit(): jit_specs = [flashinfer.gemm.gen_gemm_module()] if is_sm90a_supported(torch.device("cuda:0")): diff --git a/tests/utils/test_activation.py b/tests/utils/test_activation.py index 3854d7f576..3a81681592 100644 --- a/tests/utils/test_activation.py +++ b/tests/utils/test_activation.py @@ -18,10 +18,13 @@ import torch import flashinfer -from flashinfer.utils import get_compute_capability +from flashinfer.utils import get_compute_capability, has_flashinfer_jit_cache -@pytest.fixture(autouse=True, scope="module") +@pytest.fixture( + autouse=not has_flashinfer_jit_cache(), + scope="module", +) def warmup_jit(): flashinfer.jit.build_jit_specs( [ diff --git a/tests/utils/test_block_sparse.py b/tests/utils/test_block_sparse.py index 716e738db7..46052d18a3 100644 --- a/tests/utils/test_block_sparse.py +++ b/tests/utils/test_block_sparse.py @@ -24,9 +24,13 @@ ) import flashinfer +from flashinfer.utils import has_flashinfer_jit_cache -@pytest.fixture(autouse=True, scope="module") +@pytest.fixture( + autouse=not has_flashinfer_jit_cache(), + scope="module", +) def warmup_jit(): flashinfer.jit.build_jit_specs( gen_decode_attention_modules( diff --git a/tests/utils/test_pod_kernels.py b/tests/utils/test_pod_kernels.py index 553e6c3a8c..8900cc1b6c 100644 --- a/tests/utils/test_pod_kernels.py +++ b/tests/utils/test_pod_kernels.py @@ -23,9 +23,13 @@ import flashinfer from flashinfer.jit.attention import gen_pod_module +from flashinfer.utils import has_flashinfer_jit_cache -@pytest.fixture(autouse=True, scope="module") +@pytest.fixture( + autouse=not has_flashinfer_jit_cache(), + scope="module", +) def warmup_jit(): flashinfer.jit.build_jit_specs( gen_decode_attention_modules( diff --git a/version.txt b/version.txt index 9e11b32fca..1d0ba9ea18 100644 --- a/version.txt +++ b/version.txt @@ -1 +1 @@ -0.3.1 +0.4.0